06. Transfer Learning with Tensorflow Part 3: Scaling up (Food Vision mini)¶

We've seen how powerful transfer learning can be in the part 1 and 2 notebooks. Those were all small modelling experiments, so it's time to step things up.

It's common practice in ML and deep learning to get a model working on a small subset of the data before scaling up to the larger/full dataset.

We'll scale up from 10 Food101 classes to all 101 of them.

Our goal is to beat the results of the original Food101 paper using only 10% of the training data.


ML practitioners are serial experimenters. Start small, get a model working, see how it goes, and then gradually scale up to your end goal.

What we're going to cover¶

We're going to go through the following:

  • Downloading and preparing 10% of Food101 data
  • Training a feature extraction transfer learning model on 10% of the Food101 training data
  • Fine-tuning our feature extraction model
  • Saving and loading our trained model
  • Evaluating the performance of our Food Vision model
    • Finding the model's worst performing predictions
  • Making predictions with our Food Vision model on custom images of food
In [3]:
# Are we using a GPU?
!nvidia-smi
Mon Oct 27 19:42:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.94                 Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce GTX 1060 6GB  WDDM  |   00000000:0A:00.0  On |                  N/A |
|  0%   54C    P0             30W /  120W |    1661MiB /   6144MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1764    C+G   ...ejd91yc\AdobeNotificationClient.exe      N/A      |
|    0   N/A  N/A      1844    C+G   ...GeForce Experience\NVIDIA Share.exe      N/A      |
|    0   N/A  N/A      3428      C   ...elchupacabra\App\Bandicam\bdcam.exe      N/A      |
|    0   N/A  N/A      7528    C+G   ...\Adobe Photoshop 2021\Photoshop.exe      N/A      |
|    0   N/A  N/A     10652    C+G   ...GeForce Experience\NVIDIA Share.exe      N/A      |
|    0   N/A  N/A     11660    C+G   ...oogle\Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A     14488    C+G   ... Files\Elgato\WaveLink\WaveLink.exe      N/A      |
|    0   N/A  N/A     17812    C+G   ...remium\win64\bin\HarmonyPremium.exe      N/A      |
|    0   N/A  N/A     18036    C+G   ...cal\Microsoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A     18352    C+G   X:\Mozilla_Thunderbird\thunderbird.exe      N/A      |
|    0   N/A  N/A     19188    C+G   ...soft Office\root\Office16\EXCEL.EXE      N/A      |
|    0   N/A  N/A     19676    C+G   ...2.0_x64__cv1g1gvanyjgm\WhatsApp.exe      N/A      |
|    0   N/A  N/A     20236    C+G   ...dobe\Adobe Animate 2021\Animate.exe      N/A      |
|    0   N/A  N/A     20584    C+G   ...cal\Microsoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A     21080    C+G   ...\cef\cef.win7x64\steamwebhelper.exe      N/A      |
|    0   N/A  N/A     25892    C+G   ...oogle\Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A     36744    C+G   ...t.LockApp_cw5n1h2txyewy\LockApp.exe      N/A      |
|    0   N/A  N/A     37408    C+G   ...al\Discord\app-1.0.9205\Discord.exe      N/A      |
|    0   N/A  N/A     38768    C+G   ...2txyewy\StartMenuExperienceHost.exe      N/A      |
|    0   N/A  N/A     43584    C+G   ....Search_cw5n1h2txyewy\SearchApp.exe      N/A      |
|    0   N/A  N/A     47300    C+G   ...ekyb3d8bbwe\PhoneExperienceHost.exe      N/A      |
|    0   N/A  N/A     50096    C+G   ...-ins\Spaces\Adobe Spaces Helper.exe      N/A      |
|    0   N/A  N/A     52420    C+G   ...CBS_cw5n1h2txyewy\TextInputHost.exe      N/A      |
|    0   N/A  N/A     52568    C+G   ...\DAUM\PotPlayer\PotPlayerMini64.exe      N/A      |
|    0   N/A  N/A     53760    C+G   ...5n1h2txyewy\ShellExperienceHost.exe      N/A      |
|    0   N/A  N/A     56388    C+G   ...05.0_x64__8wekyb3d8bbwe\Cortana.exe      N/A      |
|    0   N/A  N/A     64072    C+G   ...crosoft\Edge\Application\msedge.exe      N/A      |
|    0   N/A  N/A     64204    C+G   ...1.0_x64__8wekyb3d8bbwe\Video.UI.exe      N/A      |
|    0   N/A  N/A     65052    C+G   ...on\141.0.3537.85\msedgewebview2.exe      N/A      |
|    0   N/A  N/A     66376    C+G   ...remium\win64\bin\HarmonyPremium.exe      N/A      |
|    0   N/A  N/A     66948    C+G   C:\Windows\explorer.exe                     N/A      |
|    0   N/A  N/A     69428    C+G   ....Search_cw5n1h2txyewy\SearchApp.exe      N/A      |
|    0   N/A  N/A     71684    C+G   X:\Microsoft VS Code\Code.exe               N/A      |
+-----------------------------------------------------------------------------------------+
In [27]:
import datetime
print(f'Notebook last run (end-to-end): {datetime.datetime.now()}')
Notebook last run (end-to-end): 2025-11-03 23:04:30.566595

Creating helper functions¶

We've got a script containing the useful functions we've built up so far. Rather than tediously rewriting them all, we'll download the .py file and import the ones we need.

In [28]:
# Get helper_functions.py script from course GitHub
!curl -O https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py

# Import helper functions we're going to use
import sys, os
sys.path.append(os.getcwd())

from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, walk_through_dir, compare_historys
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

100 10246  100 10246    0     0  23537      0 --:--:-- --:--:-- --:--:-- 23717

101 Food Classes: Working with less data¶

So far, our previous transfer learning experiments have worked quite well on 10 classes of food data. Now it's time to make the jump to the full 101 classes.

The original Food101 dataset has 1,000 images per class (750 for training, 250 for testing), totalling 101,000 unique images.

We could use the full dataset, but in the spirit of experimentation, we'll use only 10% of the training data and see how it goes.

This means training with only 75 images per class across all 101 classes, while keeping the full 250 test images per class.
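As a quick sanity check, here are the image counts we should expect once the data is downloaded (a small sketch based on the numbers above):

```python
num_classes = 101
train_images_per_class = 75   # 10% of the original 750 training images per class
test_images_per_class = 250   # full original test split

total_train = num_classes * train_images_per_class
total_test = num_classes * test_images_per_class

print(f"Training images: {total_train}")  # 7575
print(f"Test images: {total_test}")       # 25250
```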

Downloading and preprocessing data¶

We'll download a subset of the Food101 dataset as a zip file and use the unzip_data() helper function to unzip it.
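If you don't have helper_functions.py handy, unzip_data() can be approximated with Python's built-in zipfile module — a minimal sketch (the helper's exact implementation may differ):

```python
import zipfile

def unzip_data(filename):
    """Unzips filename into the current working directory."""
    with zipfile.ZipFile(filename, "r") as zip_ref:
        zip_ref.extractall()
```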

In [29]:
# Download data from Google Storage
!curl -O https://storage.googleapis.com/ztm_tf_course/food_vision/101_food_classes_10_percent.zip 

unzip_data('101_food_classes_10_percent.zip')

train_dir = '101_food_classes_10_percent/train/'
test_dir = '101_food_classes_10_percent/test/'
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

100 1550M  100 1550M    0     0  10.9M      0  0:02:21  0:02:21 --:--:-- 11.4M
In [30]:
# How many images/classes are there?
walk_through_dir('101_food_classes_10_percent')
There are 2 directories and 0 images in '101_food_classes_10_percent'.
There are 101 directories and 0 images in '101_food_classes_10_percent\test'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\foie_gras'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\club_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\cheese_plate'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\cup_cakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\garlic_bread'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\gnocchi'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\ice_cream'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\samosa'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\donuts'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\tuna_tartare'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\filet_mignon'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\seaweed_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\french_toast'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\chicken_curry'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\shrimp_and_grits'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\steak'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\cheesecake'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\red_velvet_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\waffles'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\churros'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\gyoza'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\lobster_roll_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\huevos_rancheros'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\breakfast_burrito'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\grilled_cheese_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\spaghetti_bolognese'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\falafel'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\poutine'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\greek_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\beef_tartare'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\fried_calamari'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\guacamole'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\ravioli'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\lobster_bisque'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\beet_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\risotto'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\crab_cakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\strawberry_shortcake'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\edamame'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\ceviche'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\hot_and_sour_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\spring_rolls'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\sashimi'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\paella'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\clam_chowder'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\miso_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\escargots'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\hot_dog'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\pulled_pork_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\bruschetta'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\panna_cotta'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\fish_and_chips'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\pad_thai'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\tiramisu'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\takoyaki'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\macarons'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\apple_pie'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\cannoli'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\scallops'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\frozen_yogurt'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\chicken_quesadilla'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\mussels'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\beef_carpaccio'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\eggs_benedict'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\spaghetti_carbonara'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\omelette'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\sushi'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\chocolate_mousse'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\beignets'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\bibimbap'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\hummus'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\pork_chop'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\chicken_wings'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\grilled_salmon'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\chocolate_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\tacos'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\hamburger'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\baby_back_ribs'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\pancakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\prime_rib'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\pizza'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\nachos'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\macaroni_and_cheese'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\bread_pudding'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\ramen'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\croque_madame'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\lasagna'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\peking_duck'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\deviled_eggs'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\french_fries'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\dumplings'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\fried_rice'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\french_onion_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\pho'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\caprese_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\oysters'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\baklava'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\creme_brulee'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\carrot_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\onion_rings'.
There are 0 directories and 250 images in '101_food_classes_10_percent\test\caesar_salad'.
There are 101 directories and 0 images in '101_food_classes_10_percent\train'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\foie_gras'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\club_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\cheese_plate'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\cup_cakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\garlic_bread'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\gnocchi'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\ice_cream'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\samosa'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\donuts'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\tuna_tartare'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\filet_mignon'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\seaweed_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\french_toast'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\chicken_curry'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\shrimp_and_grits'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\steak'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\cheesecake'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\red_velvet_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\waffles'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\churros'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\gyoza'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\lobster_roll_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\huevos_rancheros'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\breakfast_burrito'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\grilled_cheese_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\spaghetti_bolognese'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\falafel'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\poutine'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\greek_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\beef_tartare'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\fried_calamari'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\guacamole'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\ravioli'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\lobster_bisque'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\beet_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\risotto'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\crab_cakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\strawberry_shortcake'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\edamame'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\ceviche'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\hot_and_sour_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\spring_rolls'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\sashimi'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\paella'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\clam_chowder'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\miso_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\escargots'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\hot_dog'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\pulled_pork_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\bruschetta'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\panna_cotta'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\fish_and_chips'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\pad_thai'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\tiramisu'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\takoyaki'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\macarons'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\apple_pie'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\cannoli'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\scallops'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\frozen_yogurt'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\chicken_quesadilla'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\mussels'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\beef_carpaccio'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\eggs_benedict'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\spaghetti_carbonara'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\omelette'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\sushi'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\chocolate_mousse'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\beignets'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\bibimbap'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\hummus'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\pork_chop'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\chicken_wings'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\grilled_salmon'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\chocolate_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\tacos'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\hamburger'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\baby_back_ribs'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\pancakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\prime_rib'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\pizza'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\nachos'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\macaroni_and_cheese'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\bread_pudding'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\ramen'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\croque_madame'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\lasagna'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\peking_duck'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\deviled_eggs'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\french_fries'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\dumplings'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\fried_rice'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\french_onion_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\pho'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\caprese_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\oysters'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\baklava'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\creme_brulee'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\carrot_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\onion_rings'.
There are 0 directories and 75 images in '101_food_classes_10_percent\train\caesar_salad'.

As before, our data is structured as follows:

101_food_classes_10_percent <- top level folder
└───train <- training images
│   └───pizza
│   │   │   1008104.jpg
│   │   │   1638227.jpg
│   │   │   ...      
│   └───steak
│       │   1000205.jpg
│       │   1647351.jpg
│       │   ...
│   
└───test <- testing images
│   └───pizza
│   │   │   1001116.jpg
│   │   │   1507019.jpg
│   │   │   ...      
│   └───steak
│       │   100274.jpg
│       │   1653815.jpg
│       │   ...
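A per-folder listing like the one shown earlier can be generated with a small `os.walk` helper. Here's a minimal sketch (the name `walk_through_dir` follows the helper-function convention from the earlier notebooks; treat it as an assumption):

```python
import os

def walk_through_dir(dir_path):
    """Walk through dir_path, printing the number of subdirectories
    and files (images) inside each folder."""
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")
```

Calling it on the top-level folder (e.g. `walk_through_dir('101_food_classes_10_percent')`) produces a listing like the one above.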

Let's use the image_dataset_from_directory() function to turn our images and labels into a tf.data.Dataset, a TensorFlow data type that we can pass directly to our model.

For the test dataset, we're going to set shuffle=False, so we can perform repeatable evaluation and visualization on it later.

In [31]:
# setup data inputs
import tensorflow as tf
IMG_SIZE = (224,224)
train_data_all_10_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir,
                                                                                label_mode='categorical',
                                                                                image_size=IMG_SIZE)

test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
                                                                label_mode='categorical',
                                                                image_size=IMG_SIZE,
                                                                shuffle=False) # don't shuffle to keep experiments repeatable
Found 7575 files belonging to 101 classes.
Found 25250 files belonging to 101 classes.

Train a big dog model with transfer learning on 10% of 101 food classes¶

To keep experiments swift, we're going to start by using feature extraction transfer learning with a pre-trained model for a few epochs, then fine-tune it for a few more epochs.

Our goal is to see if we can beat the baseline from the original Food101 paper (50.76% accuracy) while using only 10% of the training data. To do so, we'll use:

  • A ModelCheckpoint callback to save progress during training, so we can experiment with further training later without starting from scratch every time
  • Data augmentation built right into the model
  • A headless (no top layers) EfficientNetB0 architecture from tf.keras.applications as our base model
  • A Dense layer with 101 hidden neurons (same as number of food classes) and softmax activation as the output layer
  • Categorical crossentropy as the loss function since we're dealing with more than two classes
  • The Adam optimizer with the default settings
  • Fitting for 5 full passes on the training data while evaluating on 15% of the test data

It seems like a lot, but all of these were covered before in part 2 (notebook 05).

Let's start by creating the ModelCheckpoint callback.

Since we want the model to perform well on unseen data, we'll have it monitor the validation accuracy metric and save the weights that score best on it.

In [32]:
# create checkpoint callback to save model for later use
checkpoint_path = '101_classes_10_percent_data_model_checkpoint.weights.h5'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
    save_weights_only=True, # save only the model weights
    monitor='val_accuracy', # monitor validation accuracy to determine which weights get saved
    save_best_only=True) # only keep the best resulting weights, and discard the rest

The checkpoint callback is ready. Let's create a small data augmentation model with the Sequential API. Since our training data size is reduced, augmentation will help prevent overfitting.

In [33]:
# import the required modules for model creation
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

data_augmentation = Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomHeight(0.2),
    layers.RandomWidth(0.2),
    # layers.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNetB0 as it rescales inputs internally
], name='data_augmentation')

Because data_augmentation is a Sequential model, we can insert it as a layer inside a Functional API model. That way, if we want to train the model further later, the augmentation is already built in.

Now let's put it all together and build a feature extraction transfer learning model, using tf.keras.applications.efficientnet.EfficientNetB0 as the base model.

We'll import the base model with the parameter include_top=False so we can add our own output layers: a GlobalAveragePooling2D() layer, which condenses the base model's outputs into a 1D feature vector (a usable shape for the output layer), followed by a Dense output layer.

In [34]:
# setup base model and freeze its layers (this will extract features)
base_model = tf.keras.applications.EfficientNetB0(include_top=False, weights='imagenet', input_shape=(224,224,3))
base_model.trainable = False

# setup model architecture with trainable top layers
inputs = layers.Input(shape=(224,224,3), name='input_layer') # shape of input image
x = data_augmentation(inputs) # augment images (will only happen during training)
x = base_model(x, training=False) # put the base model in inference mode (training=False) so we can use it to extract features, and not update weights
x = layers.GlobalAveragePooling2D(name='global_average_pooling')(x) # pool the outputs of the base model
outputs = layers.Dense(len(train_data_all_10_percent.class_names), activation='softmax', name='output_layer')(x) # same number of outputs as classes
model = tf.keras.Model(inputs, outputs)

An illustrated figure depicting the order of layers in our model
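To get a feel for what the GlobalAveragePooling2D layer does, here's a minimal NumPy sketch: it takes the mean of each feature map over the spatial (height, width) axes, collapsing a 4D tensor into a 2D one.

```python
import numpy as np

# Fake base model output: (batch_size, height, width, channels)
feature_maps = np.arange(24, dtype=np.float32).reshape(1, 2, 2, 6)

# Global average pooling = mean over the spatial axes (1 and 2)
pooled = feature_maps.mean(axis=(1, 2))

print(feature_maps.shape)  # (1, 2, 2, 6)
print(pooled.shape)        # (1, 6) -> one averaged value per channel
```

In our model, this turns the base model's (None, 7, 7, 1280) output into a (None, 1280) vector for the Dense layer.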

Let's inspect the model.

In [35]:
# get a summary of the model
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_layer (InputLayer)    [(None, 224, 224, 3)]     0         
                                                                 
 data_augmentation (Sequent  (None, None, None, 3)     0         
 ial)                                                            
                                                                 
 efficientnetb0 (Functional  (None, 7, 7, 1280)        4049571   
 )                                                               
                                                                 
 global_average_pooling (Gl  (None, 1280)              0         
 obalAveragePooling2D)                                           
                                                                 
 output_layer (Dense)        (None, 101)               129381    
                                                                 
=================================================================
Total params: 4178952 (15.94 MB)
Trainable params: 129381 (505.39 KB)
Non-trainable params: 4049571 (15.45 MB)
_________________________________________________________________

Nice, our Functional model shows 5 top-level layers, but each of those (like efficientnetb0) may contain many layers of its own.

Notice the split between trainable and non-trainable parameters: the trainable parameters belong only to output_layer, while the efficientnetb0 base model is frozen. We're starting with feature extraction, keeping the base model's learned patterns frozen while letting the output layer tune itself to our custom data.
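We can sanity-check the trainable-parameter count from the summary: the Dense output layer has one weight per pooled feature per class, plus one bias per class.

```python
# EfficientNetB0's pooled output has 1280 features; Food101 has 101 classes
n_features = 1280
n_classes = 101

trainable_params = n_features * n_classes + n_classes  # weights + biases
print(trainable_params)  # 129381, matching "Trainable params" in the summary
```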

Time to compile and fit.

In [20]:
# compile
model.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(), # using Adam's default learning rate
metrics=['accuracy'])

# fit
history_all_classes_10_percent = model.fit(train_data_all_10_percent,
epochs=5, # fitting 5 epochs to keep experiments short
validation_data=test_data,
validation_steps=int(0.15 * len(test_data)), # evaluate on 15% of test data, again to keep experiments quick
callbacks=[checkpoint_callback]) # callback to save best model weights on file
Epoch 1/5
237/237 [==============================] - 328s 1s/step - loss: 3.3825 - accuracy: 0.2684 - val_loss: 2.4930 - val_accuracy: 0.4396
Epoch 2/5
237/237 [==============================] - 297s 1s/step - loss: 2.2078 - accuracy: 0.4939 - val_loss: 2.0097 - val_accuracy: 0.5254
Epoch 3/5
237/237 [==============================] - 292s 1s/step - loss: 1.8206 - accuracy: 0.5706 - val_loss: 1.8636 - val_accuracy: 0.5344
Epoch 4/5
237/237 [==============================] - 293s 1s/step - loss: 1.6076 - accuracy: 0.6136 - val_loss: 1.8080 - val_accuracy: 0.5384
Epoch 5/5
237/237 [==============================] - 291s 1s/step - loss: 1.4556 - accuracy: 0.6437 - val_loss: 1.7518 - val_accuracy: 0.5445

The model is getting impressive results, but it's only being evaluated on 15% of the test data. Let's evaluate it on the full test dataset.

In [21]:
# evaluate model
results_feature_extraction_model = model.evaluate(test_data)
results_feature_extraction_model
790/790 [==============================] - 635s 803ms/step - loss: 1.5837 - accuracy: 0.5846
Out[21]:
[1.5837445259094238, 0.5845940709114075]

Well, it looks like we've just beaten the Food101 paper's baseline with only 10% of the training data! That's the strength of deep learning, and more precisely transfer learning: leveraging what a model has learned on one dataset and applying it to another.

How do the loss curves look?

In [22]:
plot_loss_curves(history_all_classes_10_percent)
(plot of training and validation loss curves)
(plot of training and validation accuracy curves)

Question: What should we expect the curves to suggest? Ideally, we want both curves to track each other closely. If the two diverge, there may be issues with overfitting or underfitting.

Fine tuning¶

Our feature extraction transfer learning model is performing well. Why not try fine-tuning a few layers in the base model and see if any improvements can be gained?

Thanks to the ModelCheckpoint callback, we have saved weights for our current best-performing model, so if fine-tuning doesn't offer any benefits, we can revert to them.

To fine-tune, we need to set trainable=True on the base model.

Because our training dataset is small (on purpose), we'll then refreeze every layer except the last 5, leaving only those trainable.

In [23]:
# unfreeze the base model
base_model.trainable = True

# refreeze layers except the last 5 layers
for layer in base_model.layers[:-5]:
    layer.trainable = False

Now that we've made changes to the Functional API model, we need to recompile it for those changes to take effect.

Because we're fine-tuning, we'll lower the learning rate by 10x to keep weight updates small, so we don't heavily disturb the pre-trained weights.

alt text

When fine-tuning and unfreezing layers of a pre-trained model, it's common practice to lower the learning rate by 10x.

In [24]:
# recompile model with lower learning rate
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(1e-4), # lower learning rate by 10 fold from the default
              metrics=['accuracy'])

Model recompiled. Let's check which layers are trainable.

In [30]:
# what layers in the model are trainable?
for layer in model.layers:
    print(layer.name, layer.trainable)
input_layer True
data_augmentation True
efficientnetb0 True
global_average_pooling True
output_layer True
In [31]:
# check which layers are trainable
for layer_number, layer in enumerate(base_model.layers):
    print(layer_number, layer.name, layer.trainable)
0 input_3 False
1 rescaling_2 False
2 normalization_2 False
3 rescaling_3 False
4 stem_conv_pad False
5 stem_conv False
6 stem_bn False
7 stem_activation False
8 block1a_dwconv False
9 block1a_bn False
10 block1a_activation False
11 block1a_se_squeeze False
12 block1a_se_reshape False
13 block1a_se_reduce False
14 block1a_se_expand False
15 block1a_se_excite False
16 block1a_project_conv False
17 block1a_project_bn False
18 block2a_expand_conv False
19 block2a_expand_bn False
20 block2a_expand_activation False
21 block2a_dwconv_pad False
22 block2a_dwconv False
23 block2a_bn False
24 block2a_activation False
25 block2a_se_squeeze False
26 block2a_se_reshape False
27 block2a_se_reduce False
28 block2a_se_expand False
29 block2a_se_excite False
30 block2a_project_conv False
31 block2a_project_bn False
32 block2b_expand_conv False
33 block2b_expand_bn False
34 block2b_expand_activation False
35 block2b_dwconv False
36 block2b_bn False
37 block2b_activation False
38 block2b_se_squeeze False
39 block2b_se_reshape False
40 block2b_se_reduce False
41 block2b_se_expand False
42 block2b_se_excite False
43 block2b_project_conv False
44 block2b_project_bn False
45 block2b_drop False
46 block2b_add False
47 block3a_expand_conv False
48 block3a_expand_bn False
49 block3a_expand_activation False
50 block3a_dwconv_pad False
51 block3a_dwconv False
52 block3a_bn False
53 block3a_activation False
54 block3a_se_squeeze False
55 block3a_se_reshape False
56 block3a_se_reduce False
57 block3a_se_expand False
58 block3a_se_excite False
59 block3a_project_conv False
60 block3a_project_bn False
61 block3b_expand_conv False
62 block3b_expand_bn False
63 block3b_expand_activation False
64 block3b_dwconv False
65 block3b_bn False
66 block3b_activation False
67 block3b_se_squeeze False
68 block3b_se_reshape False
69 block3b_se_reduce False
70 block3b_se_expand False
71 block3b_se_excite False
72 block3b_project_conv False
73 block3b_project_bn False
74 block3b_drop False
75 block3b_add False
76 block4a_expand_conv False
77 block4a_expand_bn False
78 block4a_expand_activation False
79 block4a_dwconv_pad False
80 block4a_dwconv False
81 block4a_bn False
82 block4a_activation False
83 block4a_se_squeeze False
84 block4a_se_reshape False
85 block4a_se_reduce False
86 block4a_se_expand False
87 block4a_se_excite False
88 block4a_project_conv False
89 block4a_project_bn False
90 block4b_expand_conv False
91 block4b_expand_bn False
92 block4b_expand_activation False
93 block4b_dwconv False
94 block4b_bn False
95 block4b_activation False
96 block4b_se_squeeze False
97 block4b_se_reshape False
98 block4b_se_reduce False
99 block4b_se_expand False
100 block4b_se_excite False
101 block4b_project_conv False
102 block4b_project_bn False
103 block4b_drop False
104 block4b_add False
105 block4c_expand_conv False
106 block4c_expand_bn False
107 block4c_expand_activation False
108 block4c_dwconv False
109 block4c_bn False
110 block4c_activation False
111 block4c_se_squeeze False
112 block4c_se_reshape False
113 block4c_se_reduce False
114 block4c_se_expand False
115 block4c_se_excite False
116 block4c_project_conv False
117 block4c_project_bn False
118 block4c_drop False
119 block4c_add False
120 block5a_expand_conv False
121 block5a_expand_bn False
122 block5a_expand_activation False
123 block5a_dwconv False
124 block5a_bn False
125 block5a_activation False
126 block5a_se_squeeze False
127 block5a_se_reshape False
128 block5a_se_reduce False
129 block5a_se_expand False
130 block5a_se_excite False
131 block5a_project_conv False
132 block5a_project_bn False
133 block5b_expand_conv False
134 block5b_expand_bn False
135 block5b_expand_activation False
136 block5b_dwconv False
137 block5b_bn False
138 block5b_activation False
139 block5b_se_squeeze False
140 block5b_se_reshape False
141 block5b_se_reduce False
142 block5b_se_expand False
143 block5b_se_excite False
144 block5b_project_conv False
145 block5b_project_bn False
146 block5b_drop False
147 block5b_add False
148 block5c_expand_conv False
149 block5c_expand_bn False
150 block5c_expand_activation False
151 block5c_dwconv False
152 block5c_bn False
153 block5c_activation False
154 block5c_se_squeeze False
155 block5c_se_reshape False
156 block5c_se_reduce False
157 block5c_se_expand False
158 block5c_se_excite False
159 block5c_project_conv False
160 block5c_project_bn False
161 block5c_drop False
162 block5c_add False
163 block6a_expand_conv False
164 block6a_expand_bn False
165 block6a_expand_activation False
166 block6a_dwconv_pad False
167 block6a_dwconv False
168 block6a_bn False
169 block6a_activation False
170 block6a_se_squeeze False
171 block6a_se_reshape False
172 block6a_se_reduce False
173 block6a_se_expand False
174 block6a_se_excite False
175 block6a_project_conv False
176 block6a_project_bn False
177 block6b_expand_conv False
178 block6b_expand_bn False
179 block6b_expand_activation False
180 block6b_dwconv False
181 block6b_bn False
182 block6b_activation False
183 block6b_se_squeeze False
184 block6b_se_reshape False
185 block6b_se_reduce False
186 block6b_se_expand False
187 block6b_se_excite False
188 block6b_project_conv False
189 block6b_project_bn False
190 block6b_drop False
191 block6b_add False
192 block6c_expand_conv False
193 block6c_expand_bn False
194 block6c_expand_activation False
195 block6c_dwconv False
196 block6c_bn False
197 block6c_activation False
198 block6c_se_squeeze False
199 block6c_se_reshape False
200 block6c_se_reduce False
201 block6c_se_expand False
202 block6c_se_excite False
203 block6c_project_conv False
204 block6c_project_bn False
205 block6c_drop False
206 block6c_add False
207 block6d_expand_conv False
208 block6d_expand_bn False
209 block6d_expand_activation False
210 block6d_dwconv False
211 block6d_bn False
212 block6d_activation False
213 block6d_se_squeeze False
214 block6d_se_reshape False
215 block6d_se_reduce False
216 block6d_se_expand False
217 block6d_se_excite False
218 block6d_project_conv False
219 block6d_project_bn False
220 block6d_drop False
221 block6d_add False
222 block7a_expand_conv False
223 block7a_expand_bn False
224 block7a_expand_activation False
225 block7a_dwconv False
226 block7a_bn False
227 block7a_activation False
228 block7a_se_squeeze False
229 block7a_se_reshape False
230 block7a_se_reduce False
231 block7a_se_expand False
232 block7a_se_excite False
233 block7a_project_conv True
234 block7a_project_bn True
235 top_conv True
236 top_bn True
237 top_activation True

Nice, time to fine tune the model.

Another 5 epochs should be enough to see whether fine-tuning benefits the model, though more epochs wouldn't hurt either.

We'll resume training where feature extraction left off, using the initial_epoch parameter of the fit() function.

In [33]:
# fine tune 5 more epochs
fine_tune_epochs = 10 # model has already done 5 epochs; this is the total number of epochs we're after (5 initial + 5 more)

history_all_classes_10_percent_fine_tune = model.fit(train_data_all_10_percent,
                                                     epochs=fine_tune_epochs,
                                                     validation_data=test_data,
                                                     validation_steps=int(0.15*len(test_data)),
                                                     initial_epoch=history_all_classes_10_percent.epoch[-1]) # start from previous last epoch
Epoch 5/10
237/237 [==============================] - 355s 1s/step - loss: 1.2106 - accuracy: 0.6825 - val_loss: 1.6916 - val_accuracy: 0.5559
Epoch 6/10
237/237 [==============================] - 333s 1s/step - loss: 1.0890 - accuracy: 0.7067 - val_loss: 1.7174 - val_accuracy: 0.5493
Epoch 7/10
237/237 [==============================] - 327s 1s/step - loss: 1.0212 - accuracy: 0.7298 - val_loss: 1.6693 - val_accuracy: 0.5612
Epoch 8/10
237/237 [==============================] - 331s 1s/step - loss: 0.9446 - accuracy: 0.7440 - val_loss: 1.7185 - val_accuracy: 0.5519
Epoch 9/10
237/237 [==============================] - 332s 1s/step - loss: 0.8885 - accuracy: 0.7625 - val_loss: 1.7214 - val_accuracy: 0.5551
Epoch 10/10
237/237 [==============================] - 331s 1s/step - loss: 0.8269 - accuracy: 0.7754 - val_loss: 1.7270 - val_accuracy: 0.5506

Once again, during training we only evaluated on a small portion of the test data. Let's find out how the fine-tuned model does on the full test dataset.

In [34]:
# evaluate fine-tuned model on the whole test dataset
results_all_classes_10_percent_fine_tune = model.evaluate(test_data)
results_all_classes_10_percent_fine_tune
790/790 [==============================] - 684s 865ms/step - loss: 1.5024 - accuracy: 0.6017
Out[34]:
[1.5023630857467651, 0.60166335105896]

Unfortunately, it seems there are only minimal improvements.

We may get a better idea by using the compare_historys() function to see how the training curves look.
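compare_historys() is a helper from the earlier notebooks; its core idea is to concatenate the metric lists from the two History objects so both training phases appear on one continuous curve. A minimal sketch of that step, with hypothetical accuracy values standing in for the real history.history dictionaries:

```python
# Hypothetical metric lists standing in for history.history of each fit() call
feature_extraction = {"accuracy": [0.27, 0.49, 0.57, 0.61, 0.64]}
fine_tuning = {"accuracy": [0.68, 0.71, 0.73, 0.74, 0.76, 0.78]}

# Stitch the fine-tuning metrics onto the end of the feature extraction metrics
total_accuracy = feature_extraction["accuracy"] + fine_tuning["accuracy"]

print(total_accuracy)       # one continuous series across both phases
print(len(total_accuracy))  # 11 epochs total (5 + 6)
```

The plotting code then draws this combined series with a vertical line at the epoch where fine-tuning began.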

In [40]:
compare_historys(original_history=history_all_classes_10_percent,
                 new_history=history_all_classes_10_percent_fine_tune,
                 initial_epochs=5)
(combined training curves for feature extraction and fine-tuning)

With fine-tuning, we can see that accuracy on the training data improved significantly, and the trend suggests it would keep climbing with more epochs. The validation accuracy, however, barely kept up. These are classic signs of overfitting.

This is okay though; fine-tuning often leads to overfitting on the training data, especially when the pre-trained model has already been trained on data similar to our custom problem.

In our case, the pre-trained EfficientNetB0 model was trained on ImageNet, which contains many real-life photos of food, just like our dataset.

If feature extraction already works well, fine-tuning may not be necessary for further improvement. Datasets whose images differ substantially from the pre-trained model's data will typically benefit more from fine-tuning.

Saving our trained model¶

To avoid having to retrain the model, let's save it to disk with the save() method.

In [41]:
# Save model to computer for later use if needed
model.save('101_food_class_10_percent_saved_big_dog_model')
INFO:tensorflow:Assets written to: 101_food_class_10_percent_saved_big_dog_model\assets
INFO:tensorflow:Assets written to: 101_food_class_10_percent_saved_big_dog_model\assets

Evaluating the performance of the big dog model across all different classes¶

We've got a trained and saved model, which has done fairly well in evaluation on the test dataset.

Let's go deeper into the model's performance and get some visualizations going.

For that, we can load the saved model and use it to make some predictions on the test dataset.

Note: Evaluating an ML model is as important as training one. Final metrics can be deceiving. Always visualize the model's performance on unseen data to make sure you aren't being fooled by good-looking numbers.

In [36]:
import tensorflow as tf
# download pre-trained model from google storage
!curl -O https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip
saved_model_path = '06_101_food_class_10_percent_saved_big_dog_model.zip'
unzip_data(saved_model_path) # unzip_data is a helper function from previous notebooks

model = tf.keras.models.load_model(saved_model_path.split('.')[0])
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0 44.5M    0  174k    0     0   152k      0  0:04:58  0:00:01  0:04:57  153k
  1 44.5M    1  812k    0     0   542k      0  0:01:24  0:00:01  0:01:23  543k
 24 44.5M   24 10.7M    0     0  4396k      0  0:00:10  0:00:02  0:00:08 4401k
 49 44.5M   49 22.2M    0     0  6503k      0  0:00:07  0:00:03  0:00:04 6508k
 74 44.5M   74 33.0M    0     0  7536k      0  0:00:06  0:00:04  0:00:02 7540k
 97 44.5M   97 43.6M    0     0  8140k      0  0:00:05  0:00:05 --:--:--  9.9M
100 44.5M  100 44.5M    0     0  8195k      0  0:00:05  0:00:05 --:--:-- 10.7M
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Importing a function (__inference_block6c_expand_activation_layer_call_and_return_conditional_losses_419470) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.

To make sure our loaded model is indeed a trained model, let's evaluate its performance on the test dataset.

In [46]:
# check to see if loaded model is a trained model
loaded_loss, loaded_accuracy = model.evaluate(test_data)
loaded_loss, loaded_accuracy
790/790 [==============================] - 711s 898ms/step - loss: 1.8027 - accuracy: 0.6078
Out[46]:
(1.8027207851409912, 0.6077623963356018)

Looks like the loaded model is doing just as well as it was during evaluation on the full test data.

Making predictions with our trained model¶

To evaluate the trained model, we need to make some predictions with it, and compare the predictions to the test dataset labels.

As the model has never seen the test dataset, this should give us an indication of how the model will perform in the real world on data similar to what it was trained on.

To make predictions with the model, we call its predict() method and pass in the test data.

In [37]:
# make predictions with model
pred_probs = model.predict(test_data, verbose=1) # set verbosity to see how long it will take
790/790 [==============================] - 661s 834ms/step

We just passed all test images to the model, and asked for it to make a prediction per image of food.

So how many predictions are made?

In [38]:
# how many predictions are there?
len(pred_probs)
Out[38]:
25250

With each image's prediction spread across 101 possible classes, that's 25,250 images with 101 possibilities each.

In [49]:
# what's the shape of our predictions?
pred_probs.shape
Out[49]:
(25250, 101)

What we have is called a prediction probability tensor (or array).

Let's see what the first 10 look like.

In [50]:
pred_probs[:10]
Out[50]:
array([[5.9541803e-02, 3.5742237e-06, 4.1377347e-02, ..., 1.4138752e-09,
        8.3531268e-05, 3.0897614e-03],
       [9.6401668e-01, 1.3753035e-09, 8.4780005e-04, ..., 5.4287146e-05,
        7.8361458e-12, 9.8464892e-10],
       [9.5925862e-01, 3.2533895e-05, 1.4867117e-03, ..., 7.1892083e-07,
        5.4396531e-07, 4.0275998e-05],
       ...,
       [4.7313362e-01, 1.2931211e-07, 1.4805610e-03, ..., 5.9750286e-04,
        6.6968983e-05, 2.3469302e-05],
       [4.4571716e-02, 4.7265306e-07, 1.2258544e-01, ..., 6.3498819e-06,
        7.5318171e-06, 3.6778876e-03],
       [7.2438931e-01, 1.9249797e-09, 5.2311167e-05, ..., 1.2291438e-03,
        1.5792799e-09, 9.6395903e-05]], dtype=float32)

It looks to be a bunch of tensors of really small numbers. How about we zoom into one of the tensors?

In [51]:
# we get one prediction probability per class, for each image
print(f'Number of prediction probabilities for sample 0: {len(pred_probs[0])}')
print(f'What prediction probability sample 0 looks like: \n {pred_probs[0]}')
print(f'The class with the highest predicted probability by the model for sample 0: {pred_probs[0].argmax()}')
Number of prediction probabilities for sample 0: 101
What prediction probability sample 0 looks like: 
 [5.95418029e-02 3.57422368e-06 4.13773470e-02 1.06606712e-09
 8.16151680e-09 8.66406058e-09 8.09273104e-07 8.56533518e-07
 1.98592134e-05 8.09785718e-07 3.17280868e-09 9.86751957e-07
 2.85323913e-04 7.80500442e-10 7.42301228e-04 3.89165434e-05
 6.47412026e-06 2.49774348e-06 3.78915465e-05 2.06784350e-07
 1.55385478e-05 8.15079147e-07 2.62307526e-06 2.00107877e-07
 8.38284564e-07 5.42161115e-06 3.73912280e-06 1.31505269e-08
 2.77616014e-03 2.80517943e-05 6.85629054e-10 2.55749892e-05
 1.66890954e-04 7.64081243e-10 4.04532795e-04 1.31507765e-08
 1.79575227e-06 1.44483045e-06 2.30629761e-02 8.24671304e-07
 8.53669519e-07 1.71386864e-06 7.05258026e-06 1.84024014e-08
 2.85536885e-07 7.94840162e-06 2.06818777e-06 1.85252830e-07
 3.36200756e-08 3.15226294e-04 1.04110168e-05 8.54497102e-07
 8.47418189e-01 1.05554800e-05 4.40948554e-07 3.74044794e-05
 3.53065443e-05 3.24891153e-05 6.73152244e-05 1.28526594e-08
 2.62199956e-10 1.03182419e-05 8.57441555e-05 1.05699201e-06
 2.12935470e-06 3.76377102e-05 7.59745546e-08 2.53405946e-04
 9.29065152e-07 1.25982158e-04 6.26223436e-06 1.24587913e-08
 4.05197461e-05 6.87283404e-08 1.25463464e-06 5.28879660e-08
 7.54253193e-08 7.53988934e-05 7.75409208e-05 6.40267046e-07
 9.90336275e-07 2.22261660e-05 1.50140704e-05 1.40385367e-07
 1.22326192e-05 1.90447737e-02 5.00000533e-05 4.62264643e-06
 1.53884358e-07 3.38243041e-07 3.92285360e-09 1.65638838e-07
 8.13211809e-05 4.89655076e-06 2.40683391e-07 2.31242102e-05
 3.10408417e-04 3.13802557e-05 1.41387524e-09 8.35312676e-05
 3.08976136e-03]
The class with the highest predicted probability by the model for sample 0: 52

For every image tensor we pass to the model, because of the number of output neurons and the softmax activation of the last layer (layers.Dense(len(train_data_all_10_percent.class_names), activation='softmax')), it outputs a prediction probability between 0 and 1 for each of the 101 classes.

You can consider the index with the highest value in the prediction probability tensor to be the label the model thinks is most likely.

Note: By the nature of softmax, the prediction probabilities are distributed across all classes and sum to 1.
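We can demonstrate this property with a minimal NumPy sketch (made-up logits, not our model's actual outputs): whatever the input values, softmax squashes them into probabilities that sum to 1.

```python
import numpy as np

def softmax(logits):
    # subtract the max for numerical stability before exponentiating
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()

probs = softmax(np.array([2.0, 1.0, 0.1]))
print(probs)        # largest logit -> largest probability
print(probs.sum())  # always sums to 1.0
```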

We can find the index of the class with the highest value using the argmax() method.

In [54]:
# get the class predictions of each label
pred_classes = pred_probs.argmax(axis=1)

# how do they look?
pred_classes[:10]
Out[54]:
array([52,  0,  0, 80, 79, 61, 29,  0, 85,  0], dtype=int64)

Now we've got the predicted class index for each of the samples in our test dataset :)

We'll be able to compare them to the test dataset labels and further evaluate the model.

To get the test dataset labels, we'll need to unravel the test_data object (a form of tf.data.Dataset) using the unbatch() method.

Doing so gives us access to the images and labels. Since the labels are one-hot encoded, the argmax() method is useful for finding the index of the true class.

Note: This is why shuffle=False is essential for our test dataset. If it shuffled every time we used it, the location of, say, image[0] would end up somewhere completely different, making comparisons impossible.
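As a toy illustration (with made-up labels, not our actual dataset), argmax along the class axis recovers the integer class index from one-hot encoded rows:

```python
import numpy as np

# Hypothetical one-hot encoded labels for 3 samples across 5 classes
one_hot_labels = np.array([
    [0, 0, 1, 0, 0],  # class 2
    [1, 0, 0, 0, 0],  # class 0
    [0, 0, 0, 0, 1],  # class 4
])

# argmax along axis=1 finds the position of the 1 in each row
class_indices = one_hot_labels.argmax(axis=1)
print(class_indices)  # [2 0 4]
```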

In [40]:
# Note: this might take a minute or so due to unravelling 790 batches
y_labels = []
for images, labels in test_data.unbatch(): # unbatch the test data and get images and labels
    y_labels.append(labels.numpy().argmax()) # append the index which has largest value
y_labels[:10]
Out[40]:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

The final check is to see how many labels we have.

In [41]:
# how many labels are there?
len(y_labels)
Out[41]:
25250

As expected, the number of labels matches the number of images. Time to compare the predictions to the labels.

Evaluating our models predictions¶

A very simple evaluation is to use Scikit-learn's accuracy_score() function which compares truth labels to predicted labels and returns an accuracy score.

If the order of both arrays is correct, we should get roughly the same accuracy value as the .evaluate() method gave earlier.

In [55]:
# get accuracy score by comparing predicted classes to ground truth labels
from sklearn.metrics import accuracy_score

sklearn_accuracy = accuracy_score(y_labels,pred_classes)
sklearn_accuracy
Out[55]:
0.6077623762376237
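Under the hood, accuracy_score with integer labels is just the mean of elementwise matches. A minimal sketch with made-up labels (not our actual predictions):

```python
import numpy as np

# Made-up ground truth and predicted labels for illustration
y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 1, 2, 0, 1])

# Equivalent to sklearn's accuracy_score(y_true, y_pred):
# fraction of positions where prediction equals truth
manual_acc = (y_true == y_pred).mean()
print(manual_acc)  # 4 of 5 match -> 0.8
```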
In [ ]:
# does the .evaluate() value match closely to the value above?
import numpy as np
print(f'Close? {np.isclose(loaded_accuracy, sklearn_accuracy)} | Difference: {loaded_accuracy - sklearn_accuracy}')

It looks like the order of both arrays is correct.

How about we visualize this in a confusion matrix? We'll make use of the make_confusion_matrix function from our helper_functions script.

In [73]:
# import confusion matrix from helper function
from helper_functions import make_confusion_matrix
In [74]:
# note: the following confusion matrix code is a remix of scikit-learn's plot_confusion_matrix function
import itertools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Get the class names to input for confusion matrix
class_names = test_data.class_names

# Plot a confusion matrix with all 25250 predictions, ground truth labels and 101 classes
make_confusion_matrix(y_true=y_labels,
                      y_pred=pred_classes,
                      classes=class_names,
                      figsize=(100, 100),
                      text_size=20,
                      norm=False,
                      savefig=True)

This is a very big confusion matrix. It may look daunting at first, but zooming in reveals a lot of insight into which classes get 'confused' the most, and what the model predicts for them instead.

The good news is that the majority of predictions line up diagonally from top left to bottom right, indicating the predicted class matches the true class.

It seems the model gets confused most often on visually similar foods, like filet_mignon and pork_chop, or chocolate_cake and tiramisu.
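One way to surface the most confused pair programmatically is to zero the diagonal of the confusion matrix (the correct predictions) and look for the largest remaining cell. A sketch with toy labels — in the notebook you'd build the matrix from y_labels and pred_classes with class_names instead:

```python
import numpy as np

# Toy labels for illustration only
class_names = ["filet_mignon", "pork_chop", "tiramisu"]
y_true = [0, 0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 0, 2, 2, 0]

# Build the confusion matrix (sklearn's confusion_matrix gives the same result)
cm = np.zeros((len(class_names), len(class_names)), dtype=int)
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1

np.fill_diagonal(cm, 0)  # remove correct predictions, leaving only confusions
worst = np.unravel_index(cm.argmax(), cm.shape)  # (true, predicted) indices
print(f"{class_names[worst[0]]} most often misclassified as {class_names[worst[1]]}")
```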

Since this is a classification problem, we can further evaluate the model's predictions using Scikit-Learn's classification_report() function.

In [57]:
from sklearn.metrics import classification_report
print(classification_report(y_labels, pred_classes))
              precision    recall  f1-score   support

           0       0.29      0.20      0.24       250
           1       0.51      0.69      0.59       250
           2       0.56      0.65      0.60       250
           3       0.74      0.53      0.62       250
           4       0.73      0.43      0.54       250
           5       0.34      0.54      0.42       250
           6       0.67      0.79      0.72       250
           7       0.82      0.76      0.79       250
           8       0.40      0.37      0.39       250
           9       0.62      0.44      0.51       250
          10       0.62      0.42      0.50       250
          11       0.84      0.49      0.62       250
          12       0.52      0.74      0.61       250
          13       0.56      0.60      0.58       250
          14       0.56      0.59      0.57       250
          15       0.44      0.32      0.37       250
          16       0.45      0.75      0.57       250
          17       0.37      0.51      0.43       250
          18       0.43      0.60      0.50       250
          19       0.68      0.60      0.64       250
          20       0.68      0.75      0.71       250
          21       0.35      0.64      0.45       250
          22       0.30      0.37      0.33       250
          23       0.66      0.77      0.71       250
          24       0.83      0.72      0.77       250
          25       0.76      0.71      0.73       250
          26       0.51      0.42      0.46       250
          27       0.78      0.72      0.75       250
          28       0.70      0.69      0.69       250
          29       0.70      0.68      0.69       250
          30       0.92      0.63      0.75       250
          31       0.78      0.70      0.74       250
          32       0.75      0.83      0.79       250
          33       0.89      0.98      0.94       250
          34       0.68      0.78      0.72       250
          35       0.78      0.66      0.72       250
          36       0.53      0.56      0.55       250
          37       0.30      0.55      0.39       250
          38       0.78      0.63      0.69       250
          39       0.27      0.33      0.30       250
          40       0.72      0.81      0.76       250
          41       0.81      0.62      0.70       250
          42       0.50      0.58      0.54       250
          43       0.75      0.60      0.67       250
          44       0.74      0.45      0.56       250
          45       0.77      0.85      0.81       250
          46       0.81      0.46      0.58       250
          47       0.44      0.49      0.46       250
          48       0.45      0.81      0.58       250
          49       0.50      0.44      0.47       250
          50       0.54      0.39      0.46       250
          51       0.71      0.86      0.78       250
          52       0.51      0.77      0.61       250
          53       0.67      0.68      0.68       250
          54       0.88      0.75      0.81       250
          55       0.86      0.69      0.76       250
          56       0.56      0.24      0.34       250
          57       0.62      0.45      0.52       250
          58       0.68      0.58      0.62       250
          59       0.70      0.37      0.49       250
          60       0.83      0.59      0.69       250
          61       0.54      0.81      0.65       250
          62       0.72      0.49      0.58       250
          63       0.94      0.86      0.90       250
          64       0.78      0.85      0.81       250
          65       0.82      0.82      0.82       250
          66       0.69      0.32      0.44       250
          67       0.41      0.58      0.48       250
          68       0.90      0.78      0.83       250
          69       0.84      0.82      0.83       250
          70       0.62      0.83      0.71       250
          71       0.81      0.46      0.59       250
          72       0.64      0.65      0.65       250
          73       0.51      0.44      0.47       250
          74       0.72      0.61      0.66       250
          75       0.84      0.90      0.87       250
          76       0.78      0.78      0.78       250
          77       0.36      0.27      0.31       250
          78       0.79      0.74      0.76       250
          79       0.44      0.81      0.57       250
          80       0.57      0.60      0.59       250
          81       0.65      0.70      0.68       250
          82       0.38      0.31      0.34       250
          83       0.58      0.80      0.67       250
          84       0.61      0.38      0.47       250
          85       0.44      0.74      0.55       250
          86       0.71      0.86      0.78       250
          87       0.41      0.39      0.40       250
          88       0.83      0.80      0.81       250
          89       0.71      0.31      0.43       250
          90       0.92      0.69      0.79       250
          91       0.83      0.87      0.85       250
          92       0.68      0.65      0.67       250
          93       0.31      0.38      0.34       250
          94       0.61      0.54      0.57       250
          95       0.74      0.61      0.67       250
          96       0.56      0.29      0.38       250
          97       0.45      0.74      0.56       250
          98       0.47      0.33      0.39       250
          99       0.52      0.27      0.35       250
         100       0.59      0.70      0.64       250

    accuracy                           0.61     25250
   macro avg       0.63      0.61      0.61     25250
weighted avg       0.63      0.61      0.61     25250

The classification_report() outputs precision, recall, and f1-scores per class.

A refresher:

  • Precision: The proportion of predicted positives that are actually positive. Higher precision means fewer false positives (real answer = 0, but predicted 1).
  • Recall: The proportion of actual positives the model correctly predicts. Higher recall means fewer false negatives (real answer = 1, but predicted 0).
  • f1-score: The harmonic mean of precision and recall, combining both into one metric. The higher the score, the better the balance between the two.
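As a quick sanity check, these metrics can be computed by hand from true positive, false positive and false negative counts. The counts below are made up purely for illustration, not taken from our model:

```python
# toy counts for a single class (illustrative numbers, not from our model)
tp, fp, fn = 60, 40, 20  # true positives, false positives, false negatives

precision = tp / (tp + fp)  # fraction of positive predictions that were correct
recall = tp / (tp + fn)     # fraction of actual positives the model caught
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two

print(f"precision={precision:.2f}, recall={recall:.2f}, f1={f1:.3f}")
```

Because the f1-score is a harmonic mean, it gets dragged down by whichever of precision or recall is lower, which is why it's a handy single number for spotting weak classes.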

The above output is helpful, but with this many classes it quickly becomes hard to read.

Let's see if we can make it easier through visualization.

We'll get classification_report() as a dictionary using output_dict=True.

In [71]:
# get a dictionary for the classification report
classification_report_dict = classification_report(y_labels, pred_classes, output_dict=True)
classification_report_dict
Out[71]:
{'0': {'precision': 0.29310344827586204,
  'recall': 0.204,
  'f1-score': 0.24056603773584906,
  'support': 250.0},
 '1': {'precision': 0.5088235294117647,
  'recall': 0.692,
  'f1-score': 0.5864406779661017,
  'support': 250.0},
 '2': {'precision': 0.5625,
  'recall': 0.648,
  'f1-score': 0.6022304832713755,
  'support': 250.0},
 '3': {'precision': 0.7415730337078652,
  'recall': 0.528,
  'f1-score': 0.616822429906542,
  'support': 250.0},
 '4': {'precision': 0.7346938775510204,
  'recall': 0.432,
  'f1-score': 0.5440806045340051,
  'support': 250.0},
 '5': {'precision': 0.34177215189873417,
  'recall': 0.54,
  'f1-score': 0.4186046511627907,
  'support': 250.0},
 '6': {'precision': 0.6677966101694915,
  'recall': 0.788,
  'f1-score': 0.7229357798165138,
  'support': 250.0},
 '7': {'precision': 0.8197424892703863,
  'recall': 0.764,
  'f1-score': 0.7908902691511387,
  'support': 250.0},
 '8': {'precision': 0.4025974025974026,
  'recall': 0.372,
  'f1-score': 0.3866943866943867,
  'support': 250.0},
 '9': {'precision': 0.6193181818181818,
  'recall': 0.436,
  'f1-score': 0.5117370892018779,
  'support': 250.0},
 '10': {'precision': 0.6235294117647059,
  'recall': 0.424,
  'f1-score': 0.5047619047619047,
  'support': 250.0},
 '11': {'precision': 0.8356164383561644,
  'recall': 0.488,
  'f1-score': 0.6161616161616161,
  'support': 250.0},
 '12': {'precision': 0.5196629213483146,
  'recall': 0.74,
  'f1-score': 0.6105610561056105,
  'support': 250.0},
 '13': {'precision': 0.5601503759398496,
  'recall': 0.596,
  'f1-score': 0.5775193798449613,
  'support': 250.0},
 '14': {'precision': 0.5584905660377358,
  'recall': 0.592,
  'f1-score': 0.574757281553398,
  'support': 250.0},
 '15': {'precision': 0.4388888888888889,
  'recall': 0.316,
  'f1-score': 0.3674418604651163,
  'support': 250.0},
 '16': {'precision': 0.4530120481927711,
  'recall': 0.752,
  'f1-score': 0.5654135338345865,
  'support': 250.0},
 '17': {'precision': 0.3659942363112392,
  'recall': 0.508,
  'f1-score': 0.42546063651591287,
  'support': 250.0},
 '18': {'precision': 0.4318840579710145,
  'recall': 0.596,
  'f1-score': 0.5008403361344538,
  'support': 250.0},
 '19': {'precision': 0.6832579185520362,
  'recall': 0.604,
  'f1-score': 0.6411889596602972,
  'support': 250.0},
 '20': {'precision': 0.68,
  'recall': 0.748,
  'f1-score': 0.7123809523809523,
  'support': 250.0},
 '21': {'precision': 0.350109409190372,
  'recall': 0.64,
  'f1-score': 0.4526166902404526,
  'support': 250.0},
 '22': {'precision': 0.2977346278317152,
  'recall': 0.368,
  'f1-score': 0.3291592128801431,
  'support': 250.0},
 '23': {'precision': 0.6632302405498282,
  'recall': 0.772,
  'f1-score': 0.7134935304990758,
  'support': 250.0},
 '24': {'precision': 0.8294930875576036,
  'recall': 0.72,
  'f1-score': 0.7708779443254818,
  'support': 250.0},
 '25': {'precision': 0.7574468085106383,
  'recall': 0.712,
  'f1-score': 0.734020618556701,
  'support': 250.0},
 '26': {'precision': 0.5147058823529411,
  'recall': 0.42,
  'f1-score': 0.46255506607929514,
  'support': 250.0},
 '27': {'precision': 0.776824034334764,
  'recall': 0.724,
  'f1-score': 0.7494824016563147,
  'support': 250.0},
 '28': {'precision': 0.6991869918699187,
  'recall': 0.688,
  'f1-score': 0.6935483870967742,
  'support': 250.0},
 '29': {'precision': 0.7024793388429752,
  'recall': 0.68,
  'f1-score': 0.6910569105691057,
  'support': 250.0},
 '30': {'precision': 0.9235294117647059,
  'recall': 0.628,
  'f1-score': 0.7476190476190476,
  'support': 250.0},
 '31': {'precision': 0.7802690582959642,
  'recall': 0.696,
  'f1-score': 0.7357293868921776,
  'support': 250.0},
 '32': {'precision': 0.7472924187725631,
  'recall': 0.828,
  'f1-score': 0.7855787476280834,
  'support': 250.0},
 '33': {'precision': 0.8945454545454545,
  'recall': 0.984,
  'f1-score': 0.9371428571428572,
  'support': 250.0},
 '34': {'precision': 0.6783216783216783,
  'recall': 0.776,
  'f1-score': 0.7238805970149254,
  'support': 250.0},
 '35': {'precision': 0.7819905213270142,
  'recall': 0.66,
  'f1-score': 0.7158351409978309,
  'support': 250.0},
 '36': {'precision': 0.5320754716981132,
  'recall': 0.564,
  'f1-score': 0.5475728155339806,
  'support': 250.0},
 '37': {'precision': 0.29912663755458513,
  'recall': 0.548,
  'f1-score': 0.3870056497175141,
  'support': 250.0},
 '38': {'precision': 0.7772277227722773,
  'recall': 0.628,
  'f1-score': 0.6946902654867256,
  'support': 250.0},
 '39': {'precision': 0.2694805194805195,
  'recall': 0.332,
  'f1-score': 0.2974910394265233,
  'support': 250.0},
 '40': {'precision': 0.7214285714285714,
  'recall': 0.808,
  'f1-score': 0.7622641509433963,
  'support': 250.0},
 '41': {'precision': 0.8115183246073299,
  'recall': 0.62,
  'f1-score': 0.7029478458049887,
  'support': 250.0},
 '42': {'precision': 0.5,
  'recall': 0.58,
  'f1-score': 0.5370370370370371,
  'support': 250.0},
 '43': {'precision': 0.746268656716418,
  'recall': 0.6,
  'f1-score': 0.6651884700665188,
  'support': 250.0},
 '44': {'precision': 0.7417218543046358,
  'recall': 0.448,
  'f1-score': 0.5586034912718204,
  'support': 250.0},
 '45': {'precision': 0.7745454545454545,
  'recall': 0.852,
  'f1-score': 0.8114285714285714,
  'support': 250.0},
 '46': {'precision': 0.8085106382978723,
  'recall': 0.456,
  'f1-score': 0.5831202046035806,
  'support': 250.0},
 '47': {'precision': 0.4392857142857143,
  'recall': 0.492,
  'f1-score': 0.4641509433962264,
  'support': 250.0},
 '48': {'precision': 0.4481236203090508,
  'recall': 0.812,
  'f1-score': 0.577524893314367,
  'support': 250.0},
 '49': {'precision': 0.5045454545454545,
  'recall': 0.444,
  'f1-score': 0.4723404255319149,
  'support': 250.0},
 '50': {'precision': 0.5444444444444444,
  'recall': 0.392,
  'f1-score': 0.4558139534883721,
  'support': 250.0},
 '51': {'precision': 0.7081967213114754,
  'recall': 0.864,
  'f1-score': 0.7783783783783784,
  'support': 250.0},
 '52': {'precision': 0.5092838196286472,
  'recall': 0.768,
  'f1-score': 0.6124401913875598,
  'support': 250.0},
 '53': {'precision': 0.6719367588932806,
  'recall': 0.68,
  'f1-score': 0.6759443339960238,
  'support': 250.0},
 '54': {'precision': 0.8785046728971962,
  'recall': 0.752,
  'f1-score': 0.8103448275862069,
  'support': 250.0},
 '55': {'precision': 0.86,
  'recall': 0.688,
  'f1-score': 0.7644444444444445,
  'support': 250.0},
 '56': {'precision': 0.5596330275229358,
  'recall': 0.244,
  'f1-score': 0.3398328690807799,
  'support': 250.0},
 '57': {'precision': 0.6222222222222222,
  'recall': 0.448,
  'f1-score': 0.5209302325581395,
  'support': 250.0},
 '58': {'precision': 0.6792452830188679,
  'recall': 0.576,
  'f1-score': 0.6233766233766234,
  'support': 250.0},
 '59': {'precision': 0.7045454545454546,
  'recall': 0.372,
  'f1-score': 0.4869109947643979,
  'support': 250.0},
 '60': {'precision': 0.8305084745762712,
  'recall': 0.588,
  'f1-score': 0.6885245901639344,
  'support': 250.0},
 '61': {'precision': 0.543010752688172,
  'recall': 0.808,
  'f1-score': 0.6495176848874598,
  'support': 250.0},
 '62': {'precision': 0.7218934911242604,
  'recall': 0.488,
  'f1-score': 0.5823389021479713,
  'support': 250.0},
 '63': {'precision': 0.9385964912280702,
  'recall': 0.856,
  'f1-score': 0.895397489539749,
  'support': 250.0},
 '64': {'precision': 0.7773722627737226,
  'recall': 0.852,
  'f1-score': 0.8129770992366412,
  'support': 250.0},
 '65': {'precision': 0.82, 'recall': 0.82, 'f1-score': 0.82, 'support': 250.0},
 '66': {'precision': 0.6923076923076923,
  'recall': 0.324,
  'f1-score': 0.44141689373297005,
  'support': 250.0},
 '67': {'precision': 0.4090909090909091,
  'recall': 0.576,
  'f1-score': 0.47840531561461797,
  'support': 250.0},
 '68': {'precision': 0.8981481481481481,
  'recall': 0.776,
  'f1-score': 0.8326180257510729,
  'support': 250.0},
 '69': {'precision': 0.8442622950819673,
  'recall': 0.824,
  'f1-score': 0.8340080971659919,
  'support': 250.0},
 '70': {'precision': 0.6216216216216216,
  'recall': 0.828,
  'f1-score': 0.7101200686106347,
  'support': 250.0},
 '71': {'precision': 0.8111888111888111,
  'recall': 0.464,
  'f1-score': 0.5903307888040712,
  'support': 250.0},
 '72': {'precision': 0.6417322834645669,
  'recall': 0.652,
  'f1-score': 0.6468253968253969,
  'support': 250.0},
 '73': {'precision': 0.5091743119266054,
  'recall': 0.444,
  'f1-score': 0.47435897435897434,
  'support': 250.0},
 '74': {'precision': 0.7169811320754716,
  'recall': 0.608,
  'f1-score': 0.658008658008658,
  'support': 250.0},
 '75': {'precision': 0.8389513108614233,
  'recall': 0.896,
  'f1-score': 0.8665377176015474,
  'support': 250.0},
 '76': {'precision': 0.7777777777777778,
  'recall': 0.784,
  'f1-score': 0.7808764940239044,
  'support': 250.0},
 '77': {'precision': 0.3641304347826087,
  'recall': 0.268,
  'f1-score': 0.3087557603686636,
  'support': 250.0},
 '78': {'precision': 0.7863247863247863,
  'recall': 0.736,
  'f1-score': 0.7603305785123967,
  'support': 250.0},
 '79': {'precision': 0.44130434782608696,
  'recall': 0.812,
  'f1-score': 0.571830985915493,
  'support': 250.0},
 '80': {'precision': 0.5747126436781609,
  'recall': 0.6,
  'f1-score': 0.5870841487279843,
  'support': 250.0},
 '81': {'precision': 0.6529850746268657,
  'recall': 0.7,
  'f1-score': 0.6756756756756757,
  'support': 250.0},
 '82': {'precision': 0.3804878048780488,
  'recall': 0.312,
  'f1-score': 0.34285714285714286,
  'support': 250.0},
 '83': {'precision': 0.5780346820809249,
  'recall': 0.8,
  'f1-score': 0.6711409395973155,
  'support': 250.0},
 '84': {'precision': 0.6103896103896104,
  'recall': 0.376,
  'f1-score': 0.46534653465346537,
  'support': 250.0},
 '85': {'precision': 0.4423076923076923,
  'recall': 0.736,
  'f1-score': 0.5525525525525525,
  'support': 250.0},
 '86': {'precision': 0.7081967213114754,
  'recall': 0.864,
  'f1-score': 0.7783783783783784,
  'support': 250.0},
 '87': {'precision': 0.40756302521008403,
  'recall': 0.388,
  'f1-score': 0.3975409836065574,
  'support': 250.0},
 '88': {'precision': 0.8264462809917356,
  'recall': 0.8,
  'f1-score': 0.8130081300813008,
  'support': 250.0},
 '89': {'precision': 0.7129629629629629,
  'recall': 0.308,
  'f1-score': 0.4301675977653631,
  'support': 250.0},
 '90': {'precision': 0.9153439153439153,
  'recall': 0.692,
  'f1-score': 0.7881548974943052,
  'support': 250.0},
 '91': {'precision': 0.8282442748091603,
  'recall': 0.868,
  'f1-score': 0.84765625,
  'support': 250.0},
 '92': {'precision': 0.6835443037974683,
  'recall': 0.648,
  'f1-score': 0.6652977412731006,
  'support': 250.0},
 '93': {'precision': 0.3114754098360656,
  'recall': 0.38,
  'f1-score': 0.34234234234234234,
  'support': 250.0},
 '94': {'precision': 0.6118721461187214,
  'recall': 0.536,
  'f1-score': 0.5714285714285714,
  'support': 250.0},
 '95': {'precision': 0.7427184466019418,
  'recall': 0.612,
  'f1-score': 0.6710526315789473,
  'support': 250.0},
 '96': {'precision': 0.5625,
  'recall': 0.288,
  'f1-score': 0.38095238095238093,
  'support': 250.0},
 '97': {'precision': 0.4547677261613692,
  'recall': 0.744,
  'f1-score': 0.5644916540212443,
  'support': 250.0},
 '98': {'precision': 0.4685714285714286,
  'recall': 0.328,
  'f1-score': 0.38588235294117645,
  'support': 250.0},
 '99': {'precision': 0.5193798449612403,
  'recall': 0.268,
  'f1-score': 0.35356200527704484,
  'support': 250.0},
 '100': {'precision': 0.5912162162162162,
  'recall': 0.7,
  'f1-score': 0.6410256410256411,
  'support': 250.0},
 'accuracy': 0.6077623762376237,
 'macro avg': {'precision': 0.6328666845830312,
  'recall': 0.6077623762376237,
  'f1-score': 0.6061252197245782,
  'support': 25250.0},
 'weighted avg': {'precision': 0.6328666845830311,
  'recall': 0.6077623762376237,
  'f1-score': 0.6061252197245781,
  'support': 25250.0}}

There are still quite a few values, so we'll narrow it down to the f1-score, since it combines both precision and recall.

To extract it, we'll create an empty dictionary named class_f1_scores, then loop through each item of classification_report_dict, adding each class name as a key with its f1-score as the value.

In [67]:
classification_report_dict.items()
Out[67]:
dict_items([('0', {'precision': 0.29310344827586204, 'recall': 0.204, 'f1-score': 0.24056603773584906, 'support': 250.0}), ('1', {'precision': 0.5088235294117647, 'recall': 0.692, 'f1-score': 0.5864406779661017, 'support': 250.0}), ('2', {'precision': 0.5625, 'recall': 0.648, 'f1-score': 0.6022304832713755, 'support': 250.0}), ('3', {'precision': 0.7415730337078652, 'recall': 0.528, 'f1-score': 0.616822429906542, 'support': 250.0}), ('4', {'precision': 0.7346938775510204, 'recall': 0.432, 'f1-score': 0.5440806045340051, 'support': 250.0}), ('5', {'precision': 0.34177215189873417, 'recall': 0.54, 'f1-score': 0.4186046511627907, 'support': 250.0}), ('6', {'precision': 0.6677966101694915, 'recall': 0.788, 'f1-score': 0.7229357798165138, 'support': 250.0}), ('7', {'precision': 0.8197424892703863, 'recall': 0.764, 'f1-score': 0.7908902691511387, 'support': 250.0}), ('8', {'precision': 0.4025974025974026, 'recall': 0.372, 'f1-score': 0.3866943866943867, 'support': 250.0}), ('9', {'precision': 0.6193181818181818, 'recall': 0.436, 'f1-score': 0.5117370892018779, 'support': 250.0}), ('10', {'precision': 0.6235294117647059, 'recall': 0.424, 'f1-score': 0.5047619047619047, 'support': 250.0}), ('11', {'precision': 0.8356164383561644, 'recall': 0.488, 'f1-score': 0.6161616161616161, 'support': 250.0}), ('12', {'precision': 0.5196629213483146, 'recall': 0.74, 'f1-score': 0.6105610561056105, 'support': 250.0}), ('13', {'precision': 0.5601503759398496, 'recall': 0.596, 'f1-score': 0.5775193798449613, 'support': 250.0}), ('14', {'precision': 0.5584905660377358, 'recall': 0.592, 'f1-score': 0.574757281553398, 'support': 250.0}), ('15', {'precision': 0.4388888888888889, 'recall': 0.316, 'f1-score': 0.3674418604651163, 'support': 250.0}), ('16', {'precision': 0.4530120481927711, 'recall': 0.752, 'f1-score': 0.5654135338345865, 'support': 250.0}), ('17', {'precision': 0.3659942363112392, 'recall': 0.508, 'f1-score': 0.42546063651591287, 'support': 250.0}), ('18', {'precision': 
0.4318840579710145, 'recall': 0.596, 'f1-score': 0.5008403361344538, 'support': 250.0}), ('19', {'precision': 0.6832579185520362, 'recall': 0.604, 'f1-score': 0.6411889596602972, 'support': 250.0}), ('20', {'precision': 0.68, 'recall': 0.748, 'f1-score': 0.7123809523809523, 'support': 250.0}), ('21', {'precision': 0.350109409190372, 'recall': 0.64, 'f1-score': 0.4526166902404526, 'support': 250.0}), ('22', {'precision': 0.2977346278317152, 'recall': 0.368, 'f1-score': 0.3291592128801431, 'support': 250.0}), ('23', {'precision': 0.6632302405498282, 'recall': 0.772, 'f1-score': 0.7134935304990758, 'support': 250.0}), ('24', {'precision': 0.8294930875576036, 'recall': 0.72, 'f1-score': 0.7708779443254818, 'support': 250.0}), ('25', {'precision': 0.7574468085106383, 'recall': 0.712, 'f1-score': 0.734020618556701, 'support': 250.0}), ('26', {'precision': 0.5147058823529411, 'recall': 0.42, 'f1-score': 0.46255506607929514, 'support': 250.0}), ('27', {'precision': 0.776824034334764, 'recall': 0.724, 'f1-score': 0.7494824016563147, 'support': 250.0}), ('28', {'precision': 0.6991869918699187, 'recall': 0.688, 'f1-score': 0.6935483870967742, 'support': 250.0}), ('29', {'precision': 0.7024793388429752, 'recall': 0.68, 'f1-score': 0.6910569105691057, 'support': 250.0}), ('30', {'precision': 0.9235294117647059, 'recall': 0.628, 'f1-score': 0.7476190476190476, 'support': 250.0}), ('31', {'precision': 0.7802690582959642, 'recall': 0.696, 'f1-score': 0.7357293868921776, 'support': 250.0}), ('32', {'precision': 0.7472924187725631, 'recall': 0.828, 'f1-score': 0.7855787476280834, 'support': 250.0}), ('33', {'precision': 0.8945454545454545, 'recall': 0.984, 'f1-score': 0.9371428571428572, 'support': 250.0}), ('34', {'precision': 0.6783216783216783, 'recall': 0.776, 'f1-score': 0.7238805970149254, 'support': 250.0}), ('35', {'precision': 0.7819905213270142, 'recall': 0.66, 'f1-score': 0.7158351409978309, 'support': 250.0}), ('36', {'precision': 0.5320754716981132, 'recall': 0.564, 
'f1-score': 0.5475728155339806, 'support': 250.0}), ('37', {'precision': 0.29912663755458513, 'recall': 0.548, 'f1-score': 0.3870056497175141, 'support': 250.0}), ('38', {'precision': 0.7772277227722773, 'recall': 0.628, 'f1-score': 0.6946902654867256, 'support': 250.0}), ('39', {'precision': 0.2694805194805195, 'recall': 0.332, 'f1-score': 0.2974910394265233, 'support': 250.0}), ('40', {'precision': 0.7214285714285714, 'recall': 0.808, 'f1-score': 0.7622641509433963, 'support': 250.0}), ('41', {'precision': 0.8115183246073299, 'recall': 0.62, 'f1-score': 0.7029478458049887, 'support': 250.0}), ('42', {'precision': 0.5, 'recall': 0.58, 'f1-score': 0.5370370370370371, 'support': 250.0}), ('43', {'precision': 0.746268656716418, 'recall': 0.6, 'f1-score': 0.6651884700665188, 'support': 250.0}), ('44', {'precision': 0.7417218543046358, 'recall': 0.448, 'f1-score': 0.5586034912718204, 'support': 250.0}), ('45', {'precision': 0.7745454545454545, 'recall': 0.852, 'f1-score': 0.8114285714285714, 'support': 250.0}), ('46', {'precision': 0.8085106382978723, 'recall': 0.456, 'f1-score': 0.5831202046035806, 'support': 250.0}), ('47', {'precision': 0.4392857142857143, 'recall': 0.492, 'f1-score': 0.4641509433962264, 'support': 250.0}), ('48', {'precision': 0.4481236203090508, 'recall': 0.812, 'f1-score': 0.577524893314367, 'support': 250.0}), ('49', {'precision': 0.5045454545454545, 'recall': 0.444, 'f1-score': 0.4723404255319149, 'support': 250.0}), ('50', {'precision': 0.5444444444444444, 'recall': 0.392, 'f1-score': 0.4558139534883721, 'support': 250.0}), ('51', {'precision': 0.7081967213114754, 'recall': 0.864, 'f1-score': 0.7783783783783784, 'support': 250.0}), ('52', {'precision': 0.5092838196286472, 'recall': 0.768, 'f1-score': 0.6124401913875598, 'support': 250.0}), ('53', {'precision': 0.6719367588932806, 'recall': 0.68, 'f1-score': 0.6759443339960238, 'support': 250.0}), ('54', {'precision': 0.8785046728971962, 'recall': 0.752, 'f1-score': 0.8103448275862069, 
'support': 250.0}), ('55', {'precision': 0.86, 'recall': 0.688, 'f1-score': 0.7644444444444445, 'support': 250.0}), ('56', {'precision': 0.5596330275229358, 'recall': 0.244, 'f1-score': 0.3398328690807799, 'support': 250.0}), ('57', {'precision': 0.6222222222222222, 'recall': 0.448, 'f1-score': 0.5209302325581395, 'support': 250.0}), ('58', {'precision': 0.6792452830188679, 'recall': 0.576, 'f1-score': 0.6233766233766234, 'support': 250.0}), ('59', {'precision': 0.7045454545454546, 'recall': 0.372, 'f1-score': 0.4869109947643979, 'support': 250.0}), ('60', {'precision': 0.8305084745762712, 'recall': 0.588, 'f1-score': 0.6885245901639344, 'support': 250.0}), ('61', {'precision': 0.543010752688172, 'recall': 0.808, 'f1-score': 0.6495176848874598, 'support': 250.0}), ('62', {'precision': 0.7218934911242604, 'recall': 0.488, 'f1-score': 0.5823389021479713, 'support': 250.0}), ('63', {'precision': 0.9385964912280702, 'recall': 0.856, 'f1-score': 0.895397489539749, 'support': 250.0}), ('64', {'precision': 0.7773722627737226, 'recall': 0.852, 'f1-score': 0.8129770992366412, 'support': 250.0}), ('65', {'precision': 0.82, 'recall': 0.82, 'f1-score': 0.82, 'support': 250.0}), ('66', {'precision': 0.6923076923076923, 'recall': 0.324, 'f1-score': 0.44141689373297005, 'support': 250.0}), ('67', {'precision': 0.4090909090909091, 'recall': 0.576, 'f1-score': 0.47840531561461797, 'support': 250.0}), ('68', {'precision': 0.8981481481481481, 'recall': 0.776, 'f1-score': 0.8326180257510729, 'support': 250.0}), ('69', {'precision': 0.8442622950819673, 'recall': 0.824, 'f1-score': 0.8340080971659919, 'support': 250.0}), ('70', {'precision': 0.6216216216216216, 'recall': 0.828, 'f1-score': 0.7101200686106347, 'support': 250.0}), ('71', {'precision': 0.8111888111888111, 'recall': 0.464, 'f1-score': 0.5903307888040712, 'support': 250.0}), ('72', {'precision': 0.6417322834645669, 'recall': 0.652, 'f1-score': 0.6468253968253969, 'support': 250.0}), ('73', {'precision': 0.5091743119266054, 
'recall': 0.444, 'f1-score': 0.47435897435897434, 'support': 250.0}), ('74', {'precision': 0.7169811320754716, 'recall': 0.608, 'f1-score': 0.658008658008658, 'support': 250.0}), ('75', {'precision': 0.8389513108614233, 'recall': 0.896, 'f1-score': 0.8665377176015474, 'support': 250.0}), ('76', {'precision': 0.7777777777777778, 'recall': 0.784, 'f1-score': 0.7808764940239044, 'support': 250.0}), ('77', {'precision': 0.3641304347826087, 'recall': 0.268, 'f1-score': 0.3087557603686636, 'support': 250.0}), ('78', {'precision': 0.7863247863247863, 'recall': 0.736, 'f1-score': 0.7603305785123967, 'support': 250.0}), ('79', {'precision': 0.44130434782608696, 'recall': 0.812, 'f1-score': 0.571830985915493, 'support': 250.0}), ('80', {'precision': 0.5747126436781609, 'recall': 0.6, 'f1-score': 0.5870841487279843, 'support': 250.0}), ('81', {'precision': 0.6529850746268657, 'recall': 0.7, 'f1-score': 0.6756756756756757, 'support': 250.0}), ('82', {'precision': 0.3804878048780488, 'recall': 0.312, 'f1-score': 0.34285714285714286, 'support': 250.0}), ('83', {'precision': 0.5780346820809249, 'recall': 0.8, 'f1-score': 0.6711409395973155, 'support': 250.0}), ('84', {'precision': 0.6103896103896104, 'recall': 0.376, 'f1-score': 0.46534653465346537, 'support': 250.0}), ('85', {'precision': 0.4423076923076923, 'recall': 0.736, 'f1-score': 0.5525525525525525, 'support': 250.0}), ('86', {'precision': 0.7081967213114754, 'recall': 0.864, 'f1-score': 0.7783783783783784, 'support': 250.0}), ('87', {'precision': 0.40756302521008403, 'recall': 0.388, 'f1-score': 0.3975409836065574, 'support': 250.0}), ('88', {'precision': 0.8264462809917356, 'recall': 0.8, 'f1-score': 0.8130081300813008, 'support': 250.0}), ('89', {'precision': 0.7129629629629629, 'recall': 0.308, 'f1-score': 0.4301675977653631, 'support': 250.0}), ('90', {'precision': 0.9153439153439153, 'recall': 0.692, 'f1-score': 0.7881548974943052, 'support': 250.0}), ('91', {'precision': 0.8282442748091603, 'recall': 0.868, 
'f1-score': 0.84765625, 'support': 250.0}), ('92', {'precision': 0.6835443037974683, 'recall': 0.648, 'f1-score': 0.6652977412731006, 'support': 250.0}), ('93', {'precision': 0.3114754098360656, 'recall': 0.38, 'f1-score': 0.34234234234234234, 'support': 250.0}), ('94', {'precision': 0.6118721461187214, 'recall': 0.536, 'f1-score': 0.5714285714285714, 'support': 250.0}), ('95', {'precision': 0.7427184466019418, 'recall': 0.612, 'f1-score': 0.6710526315789473, 'support': 250.0}), ('96', {'precision': 0.5625, 'recall': 0.288, 'f1-score': 0.38095238095238093, 'support': 250.0}), ('97', {'precision': 0.4547677261613692, 'recall': 0.744, 'f1-score': 0.5644916540212443, 'support': 250.0}), ('98', {'precision': 0.4685714285714286, 'recall': 0.328, 'f1-score': 0.38588235294117645, 'support': 250.0}), ('99', {'precision': 0.5193798449612403, 'recall': 0.268, 'f1-score': 0.35356200527704484, 'support': 250.0}), ('100', {'precision': 0.5912162162162162, 'recall': 0.7, 'f1-score': 0.6410256410256411, 'support': 250.0}), ('accuracy', 0.6077623762376237), ('macro avg', {'precision': 0.6328666845830312, 'recall': 0.6077623762376237, 'f1-score': 0.6061252197245782, 'support': 25250.0}), ('weighted avg', {'precision': 0.6328666845830311, 'recall': 0.6077623762376237, 'f1-score': 0.6061252197245781, 'support': 25250.0})])
In [76]:
# create empty dictionary
class_f1_scores = {}

# loop through classification report items
for k, v in classification_report_dict.items():
    if k == 'accuracy': # stop once we hit the 'accuracy' key - the keys after it ('macro avg', 'weighted avg') aren't classes
        break
    else:
        # append class names and f1-scores to new dictionary
        class_f1_scores[class_names[int(k)]] = v['f1-score'] # get classname via k's int value, while extracting 'f1-score' key
class_f1_scores
Out[76]:
{'apple_pie': 0.24056603773584906,
 'baby_back_ribs': 0.5864406779661017,
 'baklava': 0.6022304832713755,
 'beef_carpaccio': 0.616822429906542,
 'beef_tartare': 0.5440806045340051,
 'beet_salad': 0.4186046511627907,
 'beignets': 0.7229357798165138,
 'bibimbap': 0.7908902691511387,
 'bread_pudding': 0.3866943866943867,
 'breakfast_burrito': 0.5117370892018779,
 'bruschetta': 0.5047619047619047,
 'caesar_salad': 0.6161616161616161,
 'cannoli': 0.6105610561056105,
 'caprese_salad': 0.5775193798449613,
 'carrot_cake': 0.574757281553398,
 'ceviche': 0.3674418604651163,
 'cheese_plate': 0.5654135338345865,
 'cheesecake': 0.42546063651591287,
 'chicken_curry': 0.5008403361344538,
 'chicken_quesadilla': 0.6411889596602972,
 'chicken_wings': 0.7123809523809523,
 'chocolate_cake': 0.4526166902404526,
 'chocolate_mousse': 0.3291592128801431,
 'churros': 0.7134935304990758,
 'clam_chowder': 0.7708779443254818,
 'club_sandwich': 0.734020618556701,
 'crab_cakes': 0.46255506607929514,
 'creme_brulee': 0.7494824016563147,
 'croque_madame': 0.6935483870967742,
 'cup_cakes': 0.6910569105691057,
 'deviled_eggs': 0.7476190476190476,
 'donuts': 0.7357293868921776,
 'dumplings': 0.7855787476280834,
 'edamame': 0.9371428571428572,
 'eggs_benedict': 0.7238805970149254,
 'escargots': 0.7158351409978309,
 'falafel': 0.5475728155339806,
 'filet_mignon': 0.3870056497175141,
 'fish_and_chips': 0.6946902654867256,
 'foie_gras': 0.2974910394265233,
 'french_fries': 0.7622641509433963,
 'french_onion_soup': 0.7029478458049887,
 'french_toast': 0.5370370370370371,
 'fried_calamari': 0.6651884700665188,
 'fried_rice': 0.5586034912718204,
 'frozen_yogurt': 0.8114285714285714,
 'garlic_bread': 0.5831202046035806,
 'gnocchi': 0.4641509433962264,
 'greek_salad': 0.577524893314367,
 'grilled_cheese_sandwich': 0.4723404255319149,
 'grilled_salmon': 0.4558139534883721,
 'guacamole': 0.7783783783783784,
 'gyoza': 0.6124401913875598,
 'hamburger': 0.6759443339960238,
 'hot_and_sour_soup': 0.8103448275862069,
 'hot_dog': 0.7644444444444445,
 'huevos_rancheros': 0.3398328690807799,
 'hummus': 0.5209302325581395,
 'ice_cream': 0.6233766233766234,
 'lasagna': 0.4869109947643979,
 'lobster_bisque': 0.6885245901639344,
 'lobster_roll_sandwich': 0.6495176848874598,
 'macaroni_and_cheese': 0.5823389021479713,
 'macarons': 0.895397489539749,
 'miso_soup': 0.8129770992366412,
 'mussels': 0.82,
 'nachos': 0.44141689373297005,
 'omelette': 0.47840531561461797,
 'onion_rings': 0.8326180257510729,
 'oysters': 0.8340080971659919,
 'pad_thai': 0.7101200686106347,
 'paella': 0.5903307888040712,
 'pancakes': 0.6468253968253969,
 'panna_cotta': 0.47435897435897434,
 'peking_duck': 0.658008658008658,
 'pho': 0.8665377176015474,
 'pizza': 0.7808764940239044,
 'pork_chop': 0.3087557603686636,
 'poutine': 0.7603305785123967,
 'prime_rib': 0.571830985915493,
 'pulled_pork_sandwich': 0.5870841487279843,
 'ramen': 0.6756756756756757,
 'ravioli': 0.34285714285714286,
 'red_velvet_cake': 0.6711409395973155,
 'risotto': 0.46534653465346537,
 'samosa': 0.5525525525525525,
 'sashimi': 0.7783783783783784,
 'scallops': 0.3975409836065574,
 'seaweed_salad': 0.8130081300813008,
 'shrimp_and_grits': 0.4301675977653631,
 'spaghetti_bolognese': 0.7881548974943052,
 'spaghetti_carbonara': 0.84765625,
 'spring_rolls': 0.6652977412731006,
 'steak': 0.34234234234234234,
 'strawberry_shortcake': 0.5714285714285714,
 'sushi': 0.6710526315789473,
 'tacos': 0.38095238095238093,
 'takoyaki': 0.5644916540212443,
 'tiramisu': 0.38588235294117645,
 'tuna_tartare': 0.35356200527704484,
 'waffles': 0.6410256410256411}

Looks good!

The dictionary keys come out in alphabetical order of class name. However, we can order them differently.

We can turn the class_f1_scores dictionary into a pandas DataFrame and sort it by f1-score in descending order.

In [46]:
!x:\miniconda3\envs\tfenv\python -m pip install pandas
Requirement already satisfied: pandas in x:\miniconda3\envs\tfenv\lib\site-packages (2.3.3)
Requirement already satisfied: numpy>=1.22.4 in x:\miniconda3\envs\tfenv\lib\site-packages (from pandas) (1.24.3)
Requirement already satisfied: python-dateutil>=2.8.2 in x:\miniconda3\envs\tfenv\lib\site-packages (from pandas) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in x:\miniconda3\envs\tfenv\lib\site-packages (from pandas) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in x:\miniconda3\envs\tfenv\lib\site-packages (from pandas) (2025.2)
Requirement already satisfied: six>=1.5 in x:\miniconda3\envs\tfenv\lib\site-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)
In [77]:
# turn f1-scores into dataframe for visualization
import pandas as pd
f1_scores = pd.DataFrame({'class_names': list(class_f1_scores.keys()),
                          'f1-score': list(class_f1_scores.values())}).sort_values('f1-score', ascending=False)
f1_scores.head()
Out[77]:
class_names f1-score
33 edamame 0.937143
63 macarons 0.895397
75 pho 0.866538
91 spaghetti_carbonara 0.847656
69 oysters 0.834008

Let's finish it off with a horizontal bar chart.

In [82]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 25))
scores = ax.barh(range(len(f1_scores)), f1_scores["f1-score"].values)
ax.set_yticks(range(len(f1_scores)))
ax.set_yticklabels(list(f1_scores["class_names"]))
ax.set_xlabel("f1-score")
ax.set_title("F1-Scores for 101 Different Classes")
ax.invert_yaxis(); # reverse the order

def autolabel(rects): # Modified version of: https://matplotlib.org/examples/api/barchart_demo.html
  """
  Attach a text label to each bar displaying its width (its value).
  """
  for rect in rects:
    width = rect.get_width()
    ax.text(1.03*width, rect.get_y() + rect.get_height()/1.5,
            f"{width:.2f}",
            ha='center', va='bottom')

autolabel(scores)
[Figure: horizontal bar chart of per-class F1-scores]

Visualizing performance makes a world of difference. Previously we only had list upon list of numbers with variable names; now we get an indication of how well the model predicts each different class.

Findings like the individual performance of classes allow us to figure out possible next steps for our experiments. Perhaps we could collect more training data for the worst performing classes, such as apple_pie or foie_gras, or maybe those classes are just visually difficult to differentiate from others.

Exercise: Visualize the 3 worst performing classes and see if there are any trends or clues among them.
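As a hypothetical starting point for the exercise, the three worst classes can be pulled straight out of the sorted DataFrame (a tiny stand-in DataFrame is used here; in the notebook you would use the f1_scores DataFrame built above):

```python
# Hypothetical starter for the exercise. A tiny stand-in DataFrame is used
# here; in the notebook, reuse the f1_scores DataFrame built above.
import pandas as pd

f1_scores = pd.DataFrame({'class_names': ['edamame', 'steak', 'ravioli', 'pho'],
                          'f1-score': [0.94, 0.34, 0.35, 0.87]})

# sort ascending so the lowest f1-scores come first, then take the first 3
worst_3 = f1_scores.sort_values('f1-score').head(3)['class_names'].tolist()
print(worst_3)  # → ['steak', 'ravioli', 'pho']
```

From there, you could plot a handful of test images for each of these class names to look for visual patterns.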

Visualizing predictions on test images¶

We can look at numbers and graphs all we want, but we won't really know how the model performs until we visually see it predicting on images.

The model can't predict on any image we throw at it directly; each image must first be loaded into a tensor.

To begin, we will create a function to load an image into a tensor:

  • Read in a target filepath using tf.io.read_file().
  • Turn the image into a Tensor using tf.io.decode_image().
  • Resize the image to be the same size as the images our model has been trained on (224x224) using tf.image.resize().
  • Scale the image to get all the pixel values between 0 & 1 if necessary.
In [83]:
def load_and_prep_image(filename, img_shape=224, scale=True):
    '''
    Reads an image from filename, turns it into a tensor and reshapes into
    (224,224,3)

    Parameters:
    filename (str): string filename of target image
    img_shape (int): size to resize target image to. (default set to 224)
    scale (bool): whether to scale pixel values to range 0 to 1 for normalization - default is True
    '''

    # read in the image
    img = tf.io.read_file(filename)
    # decode it into a tensor
    img = tf.io.decode_image(img)
    # resize the image
    img = tf.image.resize(img, [img_shape, img_shape])
    if scale:
        # rescale the image so values are between 0 and 1
        return img/255.
    else:
        return img

Our preprocessing function is complete.

Let's write some code to:

  1. Load a few random images from the test dataset.
  2. Make predictions on them.
  3. Plot the original image(s) along with the model's predicted label, the prediction probability of said label, and the ground truth label.
In [103]:
# make predictions on a series of random images
import os
import random

plt.figure(figsize=(17,10))
for i in range(3):
    # choose a random image and random class
    class_name = random.choice(class_names) # pick random class name
    filename = random.choice(os.listdir(os.path.join(test_dir, class_name))) # go into the class directory and pick a random filename
    filepath = os.path.join(test_dir, class_name, filename) # build the full filepath from the directory, class name and file name

    # load the image and make predictions
    img = load_and_prep_image(filepath, scale=False) # use preprocess function - not using scale, as EfficientNet has done it for us
    pred_prob = model.predict(tf.expand_dims(img, axis=0)) # the model expects shape [None, 224, 224, 3], so add a batch dimension with expand_dims
    pred_class = class_names[pred_prob.argmax()] # get max probability value's index location

    # plot the image(s)
    plt.subplot(1,3,i+1)
    plt.imshow(img/255.)
    if class_name == pred_class: # change colour of text on whether it is correct or not
        title_color = 'g'
    else:
        title_color = 'r'
    plt.title(f'actual: {class_name}, pred: {pred_class}, prob: {pred_prob.max():.2f}', c=title_color)
    plt.axis(False);
1/1 [==============================] - 0s 52ms/step
1/1 [==============================] - 0s 67ms/step
1/1 [==============================] - 0s 51ms/step
[Figure: three random test images with predicted label, probability and actual label]

Going through multiple re-runs, you can clearly see that the model's wrong predictions tend to involve foods that look eerily similar to other dishes.

Finding the most wrong predictions¶

It's a good idea to go through 100+ random instances of the model's predictions to get a good idea for how it's doing.

You may notice that the model has high confidence in a certain prediction, but the class turns out to be the wrong one.

These most wrong predictions can help give further insight into the model's performance.

So why not write code to collect all predictions where the model was highly confident (probability of 0.95+) but the predicted class was wrong?
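As a quick sketch on toy data (the notebook's actual code ends up taking the top 100 wrong predictions rather than applying a hard cutoff), a 0.95+ confidence filter would look like:

```python
# Toy sketch of a confidence-threshold filter: keep predictions that were
# wrong but made with probability above 0.95. The DataFrame here is a tiny
# stand-in with the same column names as the notebook's pred_df.
import pandas as pd

pred_df = pd.DataFrame({'pred_conf': [0.99, 0.60, 0.97, 0.98],
                        'pred_correct': [False, False, True, False]})

confident_wrong = pred_df[(pred_df['pred_conf'] > 0.95) & (~pred_df['pred_correct'])]
print(len(confident_wrong))  # → 2
```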

We'll go through these steps:

  1. Get all of the image file paths in the test dataset using the list_files() method.
  2. Create a pandas DataFrame of the image filepaths, true labels, predicted classes, max prediction probabilities, true label names and predicted class names.
  • Note: We don't necessarily have to create a DataFrame like this, but it helps with visualization.
  3. Use our DataFrame to find all the wrong predictions (where the true label doesn't match the predicted class).
  4. Sort the DataFrame by the wrong predictions with the highest prediction probabilities.
  5. Visualize the images with the highest prediction probabilities but the wrong predicted class.
In [ ]:
# 1. get the filenames of all of the test data
filepaths = []
for filepath in test_data.list_files('101_food_classes_10_percent/test/*/*.jpg',
                                     shuffle=False):
    filepaths.append(filepath.numpy())
filepaths[:10]
Out[ ]:
[b'101_food_classes_10_percent\\test\\apple_pie\\1011328.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\101251.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\1034399.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\103801.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\1038694.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\1047447.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\1068632.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\110043.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\1106961.jpg',
 b'101_food_classes_10_percent\\test\\apple_pie\\1113017.jpg']
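As an aside, the same file paths can be gathered without going through a tf.data dataset. A pure-Python equivalent with pathlib (a hypothetical helper, assuming the same test/&lt;class&gt;/&lt;image&gt;.jpg layout) might look like:

```python
# Pure-Python alternative for gathering the test image paths (hypothetical
# helper, not from the notebook). sorted() mirrors shuffle=False above.
from pathlib import Path

def list_test_images(test_root):
    """Return sorted string paths of all class/image .jpg files under test_root."""
    return sorted(str(p) for p in Path(test_root).glob('*/*.jpg'))
```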

Now that we have all the test image file paths, let's combine them into a DataFrame along with:

  • true class index (y_labels)
  • predicted class index (pred_classes)
  • max probability prediction (pred_probs.max(axis=1))
  • true class name
  • predicted class name
In [109]:
# 2. create a dataframe out of current prediction data for analysis
import pandas as pd
pred_df = pd.DataFrame({'img_path': filepaths,
                        'y_true': y_labels,
                        'y_pred': pred_classes,
                        'pred_conf': pred_probs.max(axis=1), 
                        'y_true_classname': [class_names[i] for i in y_labels],
                        'y_pred_classname': [class_names[i] for i in pred_classes]})
pred_df.head()
Out[109]:
img_path y_true y_pred pred_conf y_true_classname y_pred_classname
0 b'101_food_classes_10_percent\\test\\apple_pie... 0 52 0.847418 apple_pie gyoza
1 b'101_food_classes_10_percent\\test\\apple_pie... 0 0 0.964017 apple_pie apple_pie
2 b'101_food_classes_10_percent\\test\\apple_pie... 0 0 0.959259 apple_pie apple_pie
3 b'101_food_classes_10_percent\\test\\apple_pie... 0 80 0.658607 apple_pie pulled_pork_sandwich
4 b'101_food_classes_10_percent\\test\\apple_pie... 0 79 0.367903 apple_pie prime_rib
In [110]:
# 3. use our dataframe to find all wrong predictions
pred_df['pred_correct'] = pred_df['y_true'] == pred_df['y_pred']
pred_df.head()
Out[110]:
img_path y_true y_pred pred_conf y_true_classname y_pred_classname pred_correct
0 b'101_food_classes_10_percent\\test\\apple_pie... 0 52 0.847418 apple_pie gyoza False
1 b'101_food_classes_10_percent\\test\\apple_pie... 0 0 0.964017 apple_pie apple_pie True
2 b'101_food_classes_10_percent\\test\\apple_pie... 0 0 0.959259 apple_pie apple_pie True
3 b'101_food_classes_10_percent\\test\\apple_pie... 0 80 0.658607 apple_pie pulled_pork_sandwich False
4 b'101_food_classes_10_percent\\test\\apple_pie... 0 79 0.367903 apple_pie prime_rib False
In [112]:
# 4. sort the dataframe based on the highest prediction probabilities
top_100_wrong = pred_df[pred_df['pred_correct'] == False].sort_values('pred_conf', ascending=False)[:100]
top_100_wrong.head(20)
Out[112]:
img_path y_true y_pred pred_conf y_true_classname y_pred_classname pred_correct
21810 b'101_food_classes_10_percent\\test\\scallops\... 87 29 0.999997 scallops cup_cakes False
231 b'101_food_classes_10_percent\\test\\apple_pie... 0 100 0.999995 apple_pie waffles False
15359 b'101_food_classes_10_percent\\test\\lobster_r... 61 53 0.999988 lobster_roll_sandwich hamburger False
23539 b'101_food_classes_10_percent\\test\\strawberr... 94 83 0.999987 strawberry_shortcake red_velvet_cake False
21400 b'101_food_classes_10_percent\\test\\samosa\\3... 85 92 0.999981 samosa spring_rolls False
24540 b'101_food_classes_10_percent\\test\\tiramisu\... 98 83 0.999947 tiramisu red_velvet_cake False
2511 b'101_food_classes_10_percent\\test\\bruschett... 10 61 0.999945 bruschetta lobster_roll_sandwich False
5574 b'101_food_classes_10_percent\\test\\chocolate... 22 21 0.999939 chocolate_mousse chocolate_cake False
17855 b'101_food_classes_10_percent\\test\\paella\\2... 71 65 0.999931 paella mussels False
23797 b'101_food_classes_10_percent\\test\\sushi\\16... 95 86 0.999904 sushi sashimi False
18001 b'101_food_classes_10_percent\\test\\pancakes\... 72 67 0.999903 pancakes omelette False
11642 b'101_food_classes_10_percent\\test\\garlic_br... 46 10 0.999877 garlic_bread bruschetta False
10847 b'101_food_classes_10_percent\\test\\fried_cal... 43 68 0.999872 fried_calamari onion_rings False
23631 b'101_food_classes_10_percent\\test\\strawberr... 94 83 0.999858 strawberry_shortcake red_velvet_cake False
1155 b'101_food_classes_10_percent\\test\\beef_tart... 4 5 0.999858 beef_tartare beet_salad False
10854 b'101_food_classes_10_percent\\test\\fried_cal... 43 68 0.999854 fried_calamari onion_rings False
23904 b'101_food_classes_10_percent\\test\\sushi\\33... 95 86 0.999823 sushi sashimi False
7316 b'101_food_classes_10_percent\\test\\cup_cakes... 29 83 0.999817 cup_cakes red_velvet_cake False
13144 b'101_food_classes_10_percent\\test\\gyoza\\31... 52 92 0.999799 gyoza spring_rolls False
10880 b'101_food_classes_10_percent\\test\\fried_cal... 43 68 0.999778 fried_calamari onion_rings False
In [116]:
# 5. visualize some of the most wrong examples
images_to_view = 9
start_index = 10 # change the start index to view more
plt.figure(figsize=(15,10))
for i, row in enumerate(top_100_wrong[start_index:start_index+images_to_view].itertuples()):
    plt.subplot(3,3,i+1)
    img = load_and_prep_image(row[1], scale=True) # row[1] is the image filepath
    _, _, _, _, pred_prob, y_true, y_pred, _ = row # unpack only the fields we need from each row
    plt.imshow(img)
    plt.title(f'actual: {y_true}, pred: {y_pred} \nprob: {pred_prob:.2f}')
    plt.axis(False)
[Figure: 3x3 grid of the model's most confidently wrong predictions]

Looking at the most wrong predictions, we can note a few things:

  • Some of the labels might be wrong - If our model ends up being good enough, it may correctly predict the class of an image that was mislabeled in our test dataset. In that case, we could use the model to help us improve the labeling of our data, in turn making future models better. This idea is related to active learning.

From looking at the top left image, the model predicted omelette, but the true class is supposed to be pancakes. At a glance, we notice a very yellow object that is roundish and flat in shape. Even though pancakes do involve eggs, the eggs are incorporated into a batter of flour and milk rather than served as a separate side dish.

  • More samples needed - If a certain class is consistently being classified as another class, it's probably a good idea to gather more samples of both classes in different scenarios for the model to further improve on.

Looking at the top right and middle right images, the model predicted fried calamari and onion rings the wrong way around. Seeing both images shows how visually similar these food classes are to each other, so it would be a good idea to find more examples to help the model differentiate them.

Test out the big dog model on test images, as well as custom images of food¶

We've visualized some of the model's predictions on the test dataset. Now it's time to use our model to predict on our own custom images of food.

Let's download and unzip a prepared folder of third-party food images.

In [117]:
!curl -O https://storage.googleapis.com/ztm_tf_course/food_vision/custom_food_images.zip

unzip_data('custom_food_images.zip')
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  2 12.5M    2  281k    0     0   241k      0  0:00:53  0:00:01  0:00:52  242k
 70 12.5M   70 9118k    0     0  4211k      0  0:00:03  0:00:02  0:00:01 4215k
100 12.5M  100 12.5M    0     0  5063k      0  0:00:02  0:00:02 --:--:-- 5068k

We can load these in and turn them into tensors with the load_and_prep_image() function. But first, we need a list of image filepaths.

In [119]:
# get custom food image filepaths
custom_food_images = ['custom_food_images/' + img_path for img_path in os.listdir('custom_food_images')]
custom_food_images
Out[119]:
['custom_food_images/hamburger.jpeg',
 'custom_food_images/steak.jpeg',
 'custom_food_images/sushi.jpeg',
 'custom_food_images/chicken_wings.jpeg',
 'custom_food_images/ramen.jpeg',
 'custom_food_images/pizza-dad.jpeg']

We can now use similar code to what we used previously to load in our images, make a prediction on each using the model, and then plot each image alongside its prediction.

In [121]:
# make predictions on custom food images
for img in custom_food_images:
    img = load_and_prep_image(img, scale=False) # load in target image and turn it into tensor
    pred_prob = model.predict(tf.expand_dims(img, axis=0)) # expand dims for model to put batch num
    pred_class = class_names[pred_prob.argmax()] # find predicted class label

    # plot the image
    plt.figure()
    plt.imshow(img/255.) # imshow requires float inputs to be normalized to [0, 1]
    plt.title(f'pred: {pred_class}, prob: {pred_prob.max():.2f}')
    plt.axis(False)
1/1 [==============================] - 0s 51ms/step
1/1 [==============================] - 0s 52ms/step
1/1 [==============================] - 0s 52ms/step
1/1 [==============================] - 0s 57ms/step
1/1 [==============================] - 0s 52ms/step
1/1 [==============================] - 0s 51ms/step
[Figures: predictions on each of the six custom food images]